from __future__ import annotations

import enum

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
"""

import copy
import dataclasses
import logging
import threading
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_flashinfer_cutlass_moe_fp4_allgather",
    "disable_radix_cache",
    "enable_dp_lm_head",
    "enable_fp32_lm_head",
    "flashinfer_mxfp4_moe_precision",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "pp_max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "speculative_attention_mode",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_multimodal",
    "enable_symm_mem",
    "enable_custom_logit_processor",
    "disaggregation_mode",
    "enable_deterministic_inference",
    "nsa_prefill",
    "nsa_decode",
    "multi_item_scoring_delimiter",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )
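    # Illustrative usage: Modality.from_str("image") -> Modality.IMAGE. The lookup
    # upper-cases the string first, so "Image" and "IMAGE" resolve the same way.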

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]


@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
    one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
    # The precomputed embeddings, passed as the final encoder embeddings
    # Exactly one of `feature` and `precomputed_embeddings` will be set; the other is left empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)
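    # Illustrative usage of the model-specific fallback above (names are placeholders):
    #   item = MultimodalDataItem(modality=Modality.IMAGE)
    #   item.set("image_grid_thw", value)  # unknown keys land in model_specific_data
    #   item.image_grid_thw               # ...and are read back through __getattr__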

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )
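        # Note: mrope position tensors are concatenated along dim=1, while the
        # per-request mrope_position_delta values below are concatenated along dim=0.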

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class RequestStage(str, enum.Enum):
    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        disagg_mode: Optional[DisaggregationMode] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of KV that have been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
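        # The shallow copy above keeps the caller's SamplingParams intact: `|` builds a new
        # custom_params dict, so only this request's copy carries the "__req__" back-reference.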
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key

        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None
        # The prefix length of the last prefix matching
        self.last_matched_prefix_len: int = 0

        # Whether the request is chunked. It is incremented whenever the
        # request is chunked and decremented whenever a chunked request
        # is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has already been sent to the detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
        self.has_log_time_stats: bool = False
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

        spec_alg = global_server_args_dict["speculative_algorithm"]
        return self.sampling_params.max_new_tokens == 0 and (
            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
        )

    def add_latency(self, stage: RequestStage):
        if self.metrics_collector is None:
            return

        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
        max_prefix_len = input_len - 1
        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
        max_prefix_len = max(max_prefix_len, 0)
        token_ids = self.fill_ids[:max_prefix_len]
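        # Illustrative example: with 10 tokens in fill_ids, return_logprob enabled, and
        # logprob_start_len=3, at most min(10 - 1, 3) = 3 tokens are matched against the
        # prefix cache, so logprobs can still be computed from position 3 onward.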

        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
            )
            self.last_matched_prefix_len = len(self.prefix_indices)
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
            )
            self.cur_decode_ids_len = len(self.output_ids)
        else:
            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(self.output_ids)

        return self.surr_and_decode_ids, self.read_offset - self.surr_offset

    def tail_str(self) -> str:
        tail_len = self.sampling_params.stop_str_max_len + 1
        tail_len = min(tail_len, len(self.output_ids))
        return self.tokenizer.decode(self.output_ids[-tail_len:])

    def check_match_stop_str_prefix(self) -> bool:
        """
        Check if the suffix of tail_str overlaps with any stop_str prefix
        """
        if not self.sampling_params.stop_strs:
            return False

        tail_str = self.tail_str()

        # Early return if tail_str is empty
        if not tail_str:
            return False
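        # Illustrative example: with stop_str "Hello" and a tail_str ending in "...He",
        # the 2-character suffix "He" matches a prefix of "Hello", so this returns True
        # and the caller can hold back the pending text until the match resolves.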

        for stop_str in self.sampling_params.stop_strs:
            if not stop_str:
                continue
            # Check if stop_str is contained in tail_str (fastest check first)
            if stop_str in tail_str:
                return True

            # Check if tail_str suffix matches stop_str prefix
            # Only check if stop_str is not empty, it's for stream output
            min_len = min(len(tail_str), len(stop_str))
            for i in range(1, min_len + 1):
                if tail_str[-i:] == stop_str[:i]:
                    return True

        return False

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

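        # An out-of-range token id is treated as a corrupted sample (e.g. from NaN logits);
        # replace it with a known stop token when one is available and finish the request.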
        if last_token_id > self.vocab_size or last_token_id < 0:
            if self.sampling_params.stop_token_ids:
                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
            if self.eos_token_ids:
                self.output_ids[-1] = next(iter(self.eos_token_ids))
            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tail_str()

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

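    # The two helpers below copy the KV cache entries for this request's first
    # (seqlen - 1) tokens out to a CPU buffer and back; load_kv_cache drops the
    # CPU copy once it has been restored into the allocator.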
    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInput] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = -1

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert (
                tree_cache is None
                or isinstance(tree_cache, SWARadixCache)
                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            is_prefill_only=all(req.is_prefill_only for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
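                # In other words, either the whole encoder segment is already cached or
                # none of it is; a partial prefix match never splits the encoder tokens
                # (hence the assert below).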
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
        ).to(self.device, non_blocking=True)
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        # Set batch fields needed by alloc_for_extend
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.seq_lens = seq_lens_tensor
        self.seq_lens_cpu = seq_lens_cpu
        self.extend_num_tokens = extend_num_tokens

        # Allocate memory
        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
            self
        )

        # Set fields
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            #
            # Key variables:
            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
            # - extend_input_len: Number of tokens that need to be processed in this extend batch
            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
            #    and prefix_indices are the cached/shared prefix tokens)
            #
            if req.logprob_start_len >= pre_len:
                # Optimization for prefill-only requests: When we only need logprobs at
                # positions beyond the input sequence (to score next-token likelihood), skip all
                # input logprob computation during prefill since no generation will occur.
                if self.is_prefill_only and req.logprob_start_len == len(
                    req.origin_input_ids
                ):
                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                    req.extend_logprob_start_len = req.extend_input_len
                else:
                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
                    #
                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                    # This means: "compute logprobs from position 3 onwards in extend batch"
                    req.extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len,
                        req.extend_input_len,
                        req.seqlen - 1,
                    )
            else:
                # logprob_start_len is before the current extend batch, so start from beginning
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find the global index range within origin_input_ids and shift it
                # by 1, because computing the logprob of an input position requires the
                # next token. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We need req.extend_input_len - req.extend_logprob_start_len tokens in total;
                # logprob_token_ids only covers the input-logprob portion, so pad the rest with 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
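        # Worked example (editor's sketch): for a running request with 10 prompt tokens
        # and 3 generated tokens, the cached prefix length recorded here is
        # 10 + 3 - 1 = 12 when enable_overlap is False (delta = -1) and 10 + 3 = 13 when
        # it is True (delta = 0), matching the one-step delay of output_ids noted above.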
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
        page_size = self.token_to_kv_pool_allocator.page_size
        requests = (
            self.reqs
            if selected_indices is None
            else [self.reqs[i] for i in selected_indices]
        )
        if page_size == 1:
            return len(requests)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
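        # Worked example (editor's note): with page_size == 16 and enable_overlap off,
        # a request whose seqlen is 33 satisfies (33 - 1) % 16 == 0, so its next decoded
        # token starts a new page and it is counted; a request with seqlen 34 is not.
        # With page_size == 1, every request needs a new slot, hence len(requests).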
        return (
            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(
        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
    ):
        num_tokens = (
            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
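        # Sizing sketch (editor's note): if 3 requests will start a new page on the next
        # decode step, page_size is 16, and buf_multiplier is 1, this asks the allocator
        # for 3 * 1 * 16 = 48 tokens; the tree cache is evicted first if fewer are free.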

        self._evict_tree_cache_if_needed(num_tokens)
        return self._is_available_size_sufficient(num_tokens)

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )
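        # Ordering sketch (editor's note): the list is sorted in descending key order and
        # retraction pops from the end, so the request with the fewest generated tokens
        # (and, on ties, the longest prompt) is retracted first; the likely intent is to
        # waste the least decoding progress while freeing the most cache.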

        retracted_reqs = []
        first_iter = True
        while first_iter or (
            not self.check_decode_mem(selected_indices=sorted_indices)
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
            self.release_req(idx, len(sorted_indices), server_args)

            if len(retracted_reqs) == 0:
                # Corner case: only one request left
                raise ValueError(
                    "Failed to retract any request. No space left for only one request."
                )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
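        # Arithmetic sketch (editor's note): if the 2 requests kept in the batch had
        # decoded 40 tokens in total, retract_decode_steps were 20, and their combined
        # max_new_tokens were 200, the new estimate ratio would be
        # (40 + 20 * 2) / 200 = 0.4, capped at 1.0.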

        return retracted_reqs, new_estimate_ratio, []

    def release_req(self, idx: int, remaining_req_count: int, server_args: ServerArgs):
        req = self.reqs[idx]
        seq_lens_cpu = self.seq_lens_cpu.numpy()

        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        if isinstance(self.tree_cache, ChunkCache):
            # ChunkCache does not have eviction
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
        else:
            # TODO: apply more fine-grained retraction
            last_uncached_pos = (
                len(req.prefix_indices) // server_args.page_size
            ) * server_args.page_size
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)

            # release the last node
            if self.is_hybrid:
                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
            else:
                self.tree_cache.dec_lock_ref(req.last_node)

            # NOTE(lsyin): we should use the newly evictable memory instantly.
            num_tokens = remaining_req_count * global_config.retract_decode_steps
            self._evict_tree_cache_if_needed(num_tokens)

        req.reset_for_retract()
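        # Page-alignment sketch (editor's note): with page_size == 16 and 40 cached
        # prefix tokens, last_uncached_pos is (40 // 16) * 16 = 32, so only KV slots from
        # position 32 up to the sequence length are freed here; the first two full pages
        # remain managed by the tree cache.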

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if (
            self.spec_algorithm.is_eagle()
            or self.spec_algorithm.is_standalone()
            or self.spec_algorithm.is_ngram()
        ):
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_decode()

        # Allocate memory
        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

        # Update seq_lens after allocation
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.seq_lens_cpu = self.seq_lens_cpu + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.seq_lens_cpu.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs
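        # Editor's note (sketch): each request grows by exactly one token per decode
        # step, so with bs == 8 the batch's seq_lens_sum increases by 8. The overlap
        # branch builds new tensors rather than calling add_(), presumably so a
        # previously launched forward step that still references the old seq_lens is
        # unaffected.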

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
                has_been_filtered = False
            else:
                has_been_filtered = True
            self.spec_info.filter_batch(
                new_indices=keep_indices_device,
                has_been_filtered=has_been_filtered,
            )
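        # Usage sketch (editor's note): with reqs = [r0, r1, r2] where only r1 has
        # finished and no chunked request is excluded, keep_indices becomes [0, 2]; the
        # per-request lists and batched tensors above are gathered with those indices,
        # and out_cache_loc is reset to None.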

    def merge_batch(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
        # must run while Batch.reqs still holds only this batch's requests.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)
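        # Merge sketch (editor's note): if `self` holds 2 requests without logprobs and
        # `other` holds 1 request with return_logprob=True, the elif branch above rebuilds
        # top_logprobs_nums as [0, 0] + other.top_logprobs_nums before reqs are
        # concatenated, keeping the per-request lists aligned.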

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
        )

        return ModelWorkerBatch(
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            is_prefill_only=self.is_prefill_only,
        )

    def copy(self):
        # Only include the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
        )

    def _evict_tree_cache_if_needed(self, num_tokens: int):
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
            return

        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()

            if full_available_size < num_tokens or swa_available_size < num_tokens:
                if self.tree_cache is not None:
                    full_num_tokens = max(0, num_tokens - full_available_size)
                    swa_num_tokens = max(0, num_tokens - swa_available_size)
                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
        else:
            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
                if self.tree_cache is not None:
                    self.tree_cache.evict(num_tokens)

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens
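    # Hybrid sizing sketch (editor's note): in SWA/hybrid mode, if 48 tokens are needed
    # while the full-attention pool has 40 free and the sliding-window pool has 60 free,
    # _evict_tree_cache_if_needed asks the tree cache to evict max(0, 48 - 40) = 8 full
    # tokens and max(0, 48 - 60) = 0 SWA tokens, and _is_available_size_sufficient
    # requires both pools to cover the request.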

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_ids: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None

    spec_info: Optional[SpecInput] = None

    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # Overlap scheduler related
    delay_sample_launch: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False