from __future__ import annotations

import enum

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we may remove it in the future.
"""
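
# A rough sketch (illustrative, not part of the original file) of how the three batch
# types relate during one scheduler step. `get_model_worker_batch` and
# `ForwardBatch.init_new` live outside this file and are assumed here:
#
#   batch = ScheduleBatch.init_new(
#       reqs, req_to_token_pool, token_to_kv_pool_allocator,
#       tree_cache, model_config, enable_overlap, spec_algorithm,
#   )
#   batch.prepare_for_extend()                        # allocate KV cache, build GPU tensors
#   worker_batch = batch.get_model_worker_batch()     # GPU-forward subset of the batch
#   forward_batch = ForwardBatch.init_new(worker_batch, model_runner)  # low-level tensors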

import copy
import dataclasses
import logging
import threading
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_flashinfer_cutlass_moe_fp4_allgather",
    "disable_radix_cache",
    "enable_dp_lm_head",
    "enable_fp32_lm_head",
    "flashinfer_mxfp4_moe_precision",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "pp_max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "speculative_attention_mode",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_multimodal",
    "enable_symm_mem",
    "enable_custom_logit_processor",
    "disaggregation_mode",
    "enable_deterministic_inference",
    "nsa_prefill",
    "nsa_decode",
    "multi_item_scoring_delimiter",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
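
# Example (illustrative, not part of the original file): other modules read these
# values like `global_server_args_dict["enable_nan_detection"]`. The dict is seeded
# here with the ServerArgs defaults and is expected to be updated with the live
# server arguments at startup.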

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]


@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
    one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
    # the precomputed embeddings, passed as final encoder embeddings
    # Exactly one of `feature` and `precomputed_embeddings` will be set; the other will be empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)
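
    # Example (illustrative, not part of the original source): fields that are not
    # declared on the dataclass fall through to `model_specific_data`, e.g.
    #   item.set("image_grid_thw", grid)   # stored in model_specific_data
    #   item.image_grid_thw                # read back via __getattr__ above
    # where "image_grid_thw" is a hypothetical model-specific key.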

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when two requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class RequestStage(str, enum.Enum):
    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        disagg_mode: Optional[DisaggregationMode] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of the KV cache that has been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key

        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
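        # A worked example (illustrative): with 10 prompt tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and surr_offset = 5,
        # so tokens 5..9 act as the surrounding context when detokenizing new output ids.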
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None
        # The prefix length of the last prefix matching
        self.last_matched_prefix_len: int = 0

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked and decrements whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
        self.has_log_time_stats: bool = False
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

        spec_alg = global_server_args_dict["speculative_algorithm"]
        return self.sampling_params.max_new_tokens == 0 and (
            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
        )

    def add_latency(self, stage: RequestStage):
        if self.metrics_collector is None:
            return

        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
        max_prefix_len = input_len - 1
        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
        max_prefix_len = max(max_prefix_len, 0)
        token_ids = self.fill_ids[:max_prefix_len]

        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
            )
            self.last_matched_prefix_len = len(self.prefix_indices)
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
            )
            self.cur_decode_ids_len = len(self.output_ids)
        else:
            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(self.output_ids)

        return self.surr_and_decode_ids, self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

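        # An out-of-vocabulary token id (e.g. produced when sampling over NaN logits)
        # is replaced with a stop token, if one is configured, and the request is finished.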
        if last_token_id > self.vocab_size or last_token_id < 0:
            if self.sampling_params.stop_token_ids:
                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
            if self.eos_token_ids:
                self.output_ids[-1] = next(iter(self.eos_token_ids))
            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    # spec_info: Optional[SpecInput] = None
    spec_info: Optional[SpecInput] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = -1

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert (
                tree_cache is None
                or isinstance(tree_cache, SWARadixCache)
                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            is_prefill_only=all(req.is_prefill_only for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
        ).to(self.device, non_blocking=True)
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        # Set batch fields needed by alloc_for_extend
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.seq_lens = seq_lens_tensor
        self.seq_lens_cpu = seq_lens_cpu
        self.extend_num_tokens = extend_num_tokens

        # Allocate memory
        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
            self
        )

        # Set fields
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            #
            # Key variables:
            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
            # - extend_input_len: Number of tokens that need to be processed in this extend batch
            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
            #    and prefix_indices are the cached/shared prefix tokens)
            #
            if req.logprob_start_len >= pre_len:
                # Optimization for prefill-only requests: When we only need logprobs at
                # positions beyond the input sequence (to score next-token likelihood), skip all
                # input logprob computation during prefill since no generation will occur.
                if self.is_prefill_only and req.logprob_start_len == len(
                    req.origin_input_ids
                ):
                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                    req.extend_logprob_start_len = req.extend_input_len
                else:
                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
                    #
                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                    # This means: "compute logprobs from position 3 onwards in extend batch"
                    req.extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len,
                        req.extend_input_len,
                        req.seqlen - 1,
                    )
            else:
                # logprob_start_len is before the current extend batch, so start from beginning
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We need req.extend_input_len - req.extend_logprob_start_len tokens in
                # total; logprob_token_ids covers the input logprobs, so pad the rest with 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )
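                # Together, the slice plus the zero padding contribute
                # req.extend_input_len - req.extend_logprob_start_len ids for this request.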

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()
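        # Each running (decoding) request contributes exactly one new token to the
        # mixed batch, so its extend length is 1.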

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
        page_size = self.token_to_kv_pool_allocator.page_size
        requests = (
            self.reqs
            if selected_indices is None
            else [self.reqs[i] for i in selected_indices]
        )
        if page_size == 1:
            return len(requests)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
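        # Illustrative example: with page_size = 4, a request with seqlen 9 has a KV
        # length of 8, so (9 - 1) % 4 == 0 and its next decode token opens a new page;
        # a request with seqlen 10 does not.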
        return (
            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(
        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
    ):
        num_tokens = (
            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
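        # num_tokens is a conservative bound: every request expected to cross a page
        # boundary on the next decode step is charged one full page of tokens.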

        self._evict_tree_cache_if_needed(num_tokens)
        return self._is_available_size_sufficient(num_tokens)

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )
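            # Requests are popped from the back of sorted_indices, so this descending
            # sort retracts the requests with the fewest decoded tokens first (and, on
            # ties, the ones with the longest prompts), losing the least decode progress.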

        retracted_reqs = []
        first_iter = True
        while first_iter or (
            not self.check_decode_mem(selected_indices=sorted_indices)
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
            self.release_req(idx, len(sorted_indices), server_args)

        if len(retracted_reqs) == 0:
            # Corner case: only one request left
            raise ValueError(
                "Failed to retract any request. No space left for only one request."
            )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
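        # new_estimate_ratio estimates what fraction of the total max_new_tokens budget
        # the remaining requests will have consumed after roughly retract_decode_steps
        # more decode steps, capped at 1.0.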

        return retracted_reqs, new_estimate_ratio, []

    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
        req = self.reqs[idx]
        seq_lens_cpu = self.seq_lens_cpu.numpy()

        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        if isinstance(self.tree_cache, ChunkCache):
            # ChunkCache does not have eviction
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
        else:
            # TODO: apply more fine-grained retraction
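            # Keep the page-aligned part of the cached prefix (still owned by the tree
            # cache) and free everything from the last full page boundary onwards.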
            last_uncached_pos = (
                len(req.prefix_indices) // server_args.page_size
            ) * server_args.page_size
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)

            # release the last node
            if self.is_hybrid:
                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
            else:
                self.tree_cache.dec_lock_ref(req.last_node)

            # NOTE(lsyin): we should use the newly evictable memory instantly.
            num_tokens = remaing_req_count * global_config.retract_decode_steps
            self._evict_tree_cache_if_needed(num_tokens)

        req.reset_for_retract()

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if (
            self.spec_algorithm.is_eagle()
            or self.spec_algorithm.is_standalone()
            or self.spec_algorithm.is_ngram()
        ):
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_decode()

        # Allocate memory
        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

        # Update seq_lens after allocation
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.seq_lens_cpu = self.seq_lens_cpu + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.seq_lens_cpu.add_(1)
            self.orig_seq_lens.add_(1)
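        # Each request produced exactly one new token, so the total grows by bs.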
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
                has_been_filtered = False
            else:
                has_been_filtered = True
            self.spec_info.filter_batch(
                new_indices=keep_indices_device,
                has_been_filtered=has_been_filtered,
            )

    def merge_batch(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged. This is
        # because orchestrator.merge() depends on Batch.reqs during the preparation of
        # each penalizer, so it needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
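        # Keep the per-request logprob metadata aligned with self.reqs: if only one side
        # tracks logprobs, pad the other side with defaults (0 top-logprobs, no token ids).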
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
        )

        return ModelWorkerBatch(
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            is_prefill_only=self.is_prefill_only,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
        )

    def _evict_tree_cache_if_needed(self, num_tokens: int):
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
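            # Chunk caches do not support eviction, so there is nothing to free here.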
            return

        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()

            if full_available_size < num_tokens or swa_available_size < num_tokens:
                if self.tree_cache is not None:
                    full_num_tokens = max(0, num_tokens - full_available_size)
                    swa_num_tokens = max(0, num_tokens - swa_available_size)
                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
        else:
            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
                if self.tree_cache is not None:
                    self.tree_cache.evict(num_tokens)

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_ids: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None

    spec_info: Optional[SpecInput] = None

    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # Overlap scheduler related
    delay_sample_launch: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False