from __future__ import annotations

import enum

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward pass on the GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant; we may consider removing it in the future.
"""

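# A rough sketch of how a batch flows through these structures (the method names below are
# illustrative of the call sites in scheduler.py / tp_worker.py / model_runner.py and may
# differ slightly across versions):
#
#   batch = ScheduleBatch.init_new(reqs, ...)      # CPU-side scheduling state
#   batch.prepare_for_extend()                     # allocate KV cache, build input tensors
#   worker_batch = batch.get_model_worker_batch()  # GPU-facing subset of the batch
#   forward_batch = ForwardBatch.init_new(worker_batch, model_runner)  # low-level tensors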
import copy
import dataclasses
import logging
import re
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.environ import envs
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import SWAChunkCache
from sglang.srt.mem_cache.common import (
    alloc_for_decode,
    alloc_for_extend,
    evict_from_tree_cache,
)
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.eagle_info import EagleDraftInput
    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISHED_MATCHED_REGEX(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]


@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 images and 1 audio input, there will be 2 MultimodalDataItems:
    one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None

    # the precomputed embeddings, passed as final encoder embeddings
    # One and only one of the feature and precomputed_embeddings will be empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )
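    # Illustrative access pattern (the field name below is hypothetical): per-model extras
    # that are not declared on the dataclass live in `model_specific_data` and are resolved
    # here, e.g. `item.set("image_grid_thw", grid)` followed by `item.image_grid_thw`.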

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret
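    # Illustrative: MultimodalDataItem.from_dict({"modality": "IMAGE", "feature": pixel_values})
    # accepts the modality either as a Modality member or as its string name.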

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        merge image inputs when requests are being merged
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class RequestStage(str, enum.Enum):
    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        disagg_mode: Optional[DisaggregationMode] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
        http_worker_ipc: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # for cross-encoder model
        self.token_type_ids = token_type_ids

        # The length of KV that have been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # For multi-http worker
        self.http_worker_ipc = http_worker_ipc

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key

        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None
        self.mamba_pool_idx: Optional[torch.Tensor] = None  # shape (1)

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # finished position (in output_ids), used when checking stop conditions with speculative decoding
        self.finished_len = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None
        # The prefix length of the last prefix matching
        self.last_matched_prefix_len: int = 0

        # The number of times the request is chunked. It increments whenever
        # the request is chunked, and decrements whenever a chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # The number of accepted tokens in speculative decoding for this request.
        # This is used to compute the acceptance rate and average acceptance length per request.
        self.spec_accepted_tokens = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
        self.has_log_time_stats: bool = False
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
        spec_alg = get_global_server_args().speculative_algorithm
        return self.sampling_params.max_new_tokens == 0 and spec_alg is None
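    # Example (illustrative): a prefill-only scoring request sets max_new_tokens to 0 and,
    # as long as no speculative algorithm is configured, skips all decode-side work.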

    @property
    def output_ids_through_stop(self) -> List[int]:
        """Get the output ids through the stop condition. Stop position is included."""
        if self.finished_len is not None:
            return self.output_ids[: self.finished_len]
        return self.output_ids

    def add_latency(self, stage: RequestStage):
        if self.metrics_collector is None:
            return

        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now
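    # Illustrative flow (assuming a metrics collector is attached): the scheduler calls
    # `req.add_latency(RequestStage.PREFILL_WAITING)` when the request leaves the waiting
    # queue; this records `now - last_tic` for that stage and resets the timer for the next stage.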

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
        max_prefix_len = input_len - 1
        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
        max_prefix_len = max(max_prefix_len, 0)
        token_ids = self.fill_ids[:max_prefix_len]

        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
                **(
                    {"req": self, "cow_mamba": True}
                    if isinstance(tree_cache, MambaRadixCache)
                    else {}
                ),
            )
            self.last_matched_prefix_len = len(self.prefix_indices)
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        output_ids = self.output_ids_through_stop

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
            )
            self.cur_decode_ids_len = len(output_ids)
        else:
            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(output_ids)

        return self.surr_and_decode_ids, self.read_offset - self.surr_offset
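    # Worked example (illustrative): with 10 unpadded input ids and
    # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call sets read_offset = 10 and
    # surr_offset = 5, so the detokenizer decodes ids[5:] but only surfaces text produced
    # after the first read_offset - surr_offset = 5 surrounding tokens.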

    def tail_str(self) -> str:
        # Check stop strings and stop regex patterns together
        max_len_tail_str = 0  # default when no stop strings / regexes are configured
        if (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            max_len_tail_str = max(
                self.sampling_params.stop_str_max_len + 1,
                self.sampling_params.stop_regex_max_len + 1,
            )

        tail_len = min((max_len_tail_str + 1), len(self.output_ids))
        return self.tokenizer.decode(self.output_ids[-tail_len:])

    def check_match_stop_str_prefix(self) -> bool:
        """
        Check if the suffix of tail_str overlaps with any stop_str prefix
        """
        if not self.sampling_params.stop_strs:
            return False

        tail_str = self.tail_str()

        # Early return if tail_str is empty
        if not tail_str:
            return False

        for stop_str in self.sampling_params.stop_strs:
            if not stop_str:
                continue
            # Check if stop_str is contained in tail_str (fastest check first)
            if stop_str in tail_str:
                return True

            # Check if tail_str suffix matches stop_str prefix
            # Only check if stop_str is not empty, it's for stream output
            min_len = min(len(tail_str), len(stop_str))
            for i in range(1, min_len + 1):
                if tail_str[-i:] == stop_str[:i]:
                    return True

        return False
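    # Example (illustrative): with stop_strs=["</answer>"] and a decoded tail ending in
    # "</ans", nothing is fully matched yet, but the tail suffix overlaps a prefix of the
    # stop string, so this returns True and streaming of that text can be held back.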

    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
        if self.sampling_params.ignore_eos:
            return False

        # Check stop token ids
        matched_eos = False

        for i, token_id in enumerate(new_accepted_tokens):
            if self.sampling_params.stop_token_ids:
                matched_eos |= token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
                self.finished_len = matched_pos + 1
                return True

        return False

    def _check_str_based_finish(self):
        if (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            tail_str = self.tail_str()
            # Check stop strings
            if len(self.sampling_params.stop_strs) > 0:
                for stop_str in self.sampling_params.stop_strs:
                    if stop_str in tail_str or stop_str in self.decoded_text:
                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                        return True

            # Check stop regex
            if len(self.sampling_params.stop_regex_strs) > 0:
                for stop_regex_str in self.sampling_params.stop_regex_strs:
                    if re.search(stop_regex_str, tail_str):
                        self.finished_reason = FINISHED_MATCHED_REGEX(
                            matched=stop_regex_str
                        )
                        return True

        return False

    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
        for i, token_id in enumerate(new_accepted_tokens):
            if token_id > self.vocab_size or token_id < 0:
                offset = len(self.output_ids) - len(new_accepted_tokens) + i
                if self.sampling_params.stop_token_ids:
                    self.output_ids[offset] = next(
                        iter(self.sampling_params.stop_token_ids)
                    )
                if self.eos_token_ids:
                    self.output_ids[offset] = next(iter(self.eos_token_ids))
                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
                self.finished_len = offset + 1
                return True

        return False

    def check_finished(self, new_accepted_len: int = 1):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            self.finished_len = self.sampling_params.max_new_tokens
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        new_accepted_tokens = self.output_ids[-new_accepted_len:]

        if self._check_token_based_finish(new_accepted_tokens):
            return

        if self._check_vocab_boundary_finish(new_accepted_tokens):
            return

        if self._check_str_based_finish():
            return

    def reset_for_retract(self):
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.mamba_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu
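    # Illustrative pairing: when a request is retracted with CPU offloading enabled, the
    # scheduler may call offload_kv_cache() to stash its KV entries on the host, and later
    # load_kv_cache() to restore them into the same token slots before the request resumes.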

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It comes as an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInput] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = -1

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert (
                tree_cache is None
                or isinstance(tree_cache, SWARadixCache)
                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            is_prefill_only=all(req.is_prefill_only for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"
    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
        ).to(self.device, non_blocking=True)
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        # Set batch fields needed by alloc_for_extend
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.seq_lens = seq_lens_tensor
        self.seq_lens_cpu = seq_lens_cpu
        self.extend_num_tokens = extend_num_tokens

        # Allocate memory
        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
            self
        )

        # Set fields
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            #
            # Key variables:
            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
            # - extend_input_len: Number of tokens that need to be processed in this extend batch
            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
            #    and prefix_indices are the cached/shared prefix tokens)
            #
1231
            if req.logprob_start_len >= pre_len:
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
                # Optimization for prefill-only requests: When we only need logprobs at
                # positions beyond the input sequence (to score next-token likelihood), skip all
                # input logprob computation during prefill since no generation will occur.
                if self.is_prefill_only and req.logprob_start_len == len(
                    req.origin_input_ids
                ):
                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                    req.extend_logprob_start_len = req.extend_input_len
                else:
                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
                    #
                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                    # This means: "compute logprobs from position 3 onwards in extend batch"
                    req.extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len,
                        req.extend_input_len,
                        req.seqlen - 1,
                    )
1251
            else:
1252
                # logprob_start_len is before the current extend batch, so start from beginning
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
                req.extend_logprob_start_len = 0
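            # Illustrative worked example for the conversion above (hypothetical
            # values, not executed): with a cached prefix of pre_len=2, a 5-token
            # prompt, and logprob_start_len=3, the first branch gives
            #   >>> min(3 - 2, 5 - 2, 5 - 1)
            #   1
            # i.e. input logprobs start from the second token of this extend batch
            # rather than from absolute position 3.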

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None
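        # Illustrative sketch of the slide-by-one window used above (hypothetical
        # values, not executed, assuming logprob_start_len=0): for the first chunk
        # of a chunked prefill,
        #   >>> origin_input_ids = [11, 12, 13, 14]
        #   >>> prefix_len, fill_len = 0, 2
        #   >>> origin_input_ids[prefix_len + 1 : fill_len + 1]
        #   [12, 13]
        # and the final chunk, which runs past the end of origin_input_ids, is padded
        # with zeros up to extend_input_len - extend_logprob_start_len tokens.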

        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)
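        # Illustrative sketch for the prefix length bookkeeping above (hypothetical
        # values, not executed): a running request with 10 prompt tokens and 4
        # already-generated tokens contributes exactly one new token to the mixed
        # batch, so without the overlap scheduler its cached prefix length is
        #   >>> len_origin, len_output, delta = 10, 4, -1
        #   >>> len_origin + len_output + delta
        #   13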

    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
        page_size = self.token_to_kv_pool_allocator.page_size
        requests = (
            self.reqs
            if selected_indices is None
            else [self.reqs[i] for i in selected_indices]
        )
        if page_size == 1:
            return len(requests)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
        return (
            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
        )
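    # Illustrative sketch for new_page_count_next_decode (hypothetical values, not
    # executed): with page_size=4 and two requests whose seqlens are 8 and 9, only
    # the request whose current KV length (seqlen - 1) is already page-aligned needs
    # a fresh page for the next decode step:
    #   >>> page_size, seqlens = 4, [8, 9]
    #   >>> sum(1 for s in seqlens if (s - 1) % page_size == 0)
    #   1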

    def check_decode_mem(
        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
    ):
        num_tokens = (
            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        evict_from_tree_cache(self.tree_cache, num_tokens)
        return self._is_available_size_sufficient(num_tokens)
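    # Illustrative sketch for check_decode_mem (hypothetical values, not executed):
    # if two requests each need a new page next step, with page_size=16 and
    # buf_multiplier=1 the scheduler must be able to reserve
    #   >>> 2 * 1 * 16
    #   32
    # tokens; it first tries to evict that many tokens from the tree cache and then
    # checks the allocator's available size.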

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        retracted_reqs = []
        first_iter = True
        while first_iter or (
            not self.check_decode_mem(selected_indices=sorted_indices)
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
            # release memory and don't insert into the tree because we need the space instantly
            self.release_req(idx, len(sorted_indices), server_args)

            if len(retracted_reqs) == 0:
                # Corner case: only one request left
                raise ValueError(
                    "Failed to retract any request. No space left for only one request."
                )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens
            + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
        ) / (
            total_max_new_tokens + 1
        )  # avoid zero division
        new_estimate_ratio = min(1.0, new_estimate_ratio)
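        # Illustrative sketch for the ratio above (hypothetical values, not executed,
        # assuming SGLANG_RETRACT_DECODE_STEPS resolves to 20): with 2 surviving
        # requests that have decoded 100 tokens in total and a combined
        # max_new_tokens budget of 256,
        #   >>> round(min(1.0, (100 + 20 * 2) / (256 + 1)), 3)
        #   0.545
        # i.e. each surviving request is assumed to decode roughly 20 more steps when
        # re-estimating how much memory the batch will need.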

        return retracted_reqs, new_estimate_ratio, []

    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
        req = self.reqs[idx]

        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        # TODO (csy): for preempted requests, we may want to insert into the tree
        self.tree_cache.cache_finished_req(req, is_insert=False)
        # NOTE(lsyin): we should use the newly evictable memory instantly.
        num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
        evict_from_tree_cache(self.tree_cache, num_tokens)

        req.reset_for_retract()

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    @property
    def is_v2_eagle(self):
        # FIXME: finally deprecate is_v2_eagle
        return self.enable_overlap and self.spec_algorithm.is_eagle()

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.is_v2_eagle:
            # TODO(spec-v2): all v2 spec should go through this path
            draft_input: EagleDraftInput = self.spec_info
            draft_input.prepare_for_decode(self)

        if not self.spec_algorithm.is_none():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_decode()

        # Allocate memory
        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

        # Update seq_lens after allocation
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.seq_lens_cpu = self.seq_lens_cpu + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.seq_lens_cpu.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs
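        # Illustrative sketch for the length update above (hypothetical values, not
        # executed): in overlap mode the previous step may still be reading the old
        # tensors, so fresh ones are created, e.g.
        #   >>> import torch
        #   >>> seq_lens = torch.tensor([5, 7])
        #   >>> (seq_lens + 1).tolist()
        #   [6, 8]
        # while the non-overlap path bumps the same tensors in place with add_(1).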

    def maybe_wait_verify_done(self):
        if self.is_v2_eagle:
            draft_input: EagleDraftInput = self.spec_info
            if draft_input.verify_done is not None:
                draft_input.verify_done.synchronize()

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        # FIXME(lsyin): used here to get the correct seq_lens
        # The batch has been launched but we need it verified to get correct next batch info
        self.maybe_wait_verify_done()

        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
                has_been_filtered = False
            else:
                has_been_filtered = True
            self.spec_info.filter_batch(
                new_indices=keep_indices_device,
                has_been_filtered=has_been_filtered,
            )
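    # Illustrative sketch for filter_batch (hypothetical values, not executed): with
    # four requests of which requests 1 and 3 have finished, the default keep set is
    #   >>> finished = [False, True, False, True]
    #   >>> [i for i in range(len(finished)) if not finished[i]]
    #   [0, 2]
    # and every per-request field (seq_lens, output_ids, logprob metadata, ...) is
    # re-indexed with these keep_indices on both the CPU and the device copies.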

    def merge_batch(self, other: "ScheduleBatch"):
        # NOTE: in v2 eagle mode, we do not need to wait for verify here because
        # 1) the current batch is always prefill, whose seq_lens and allocate_lens are not a future
        # 2) the other batch is always decode, which finished in the previous step

        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during the preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
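        # Illustrative sketch for the merge above (hypothetical values, not executed):
        # if only `other` requested logprobs, the merged metadata is left-padded with
        # defaults for self's requests:
        #   >>> len_self_reqs, other_top_logprobs_nums = 2, [5]
        #   >>> [0] * len_self_reqs + other_top_logprobs_nums
        #   [0, 0, 5]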
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
        )

        return ModelWorkerBatch(
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            is_prefill_only=self.is_prefill_only,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            req_to_token_pool=self.req_to_token_pool,
            req_pool_indices=self.req_pool_indices,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
            seq_lens_cpu=self.seq_lens_cpu,
            enable_overlap=self.enable_overlap,
        )

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens
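    # Illustrative sketch for the hybrid (SWA) check above (hypothetical values, not
    # executed): both KV pools must be able to hold the new tokens,
    #   >>> full_avail, swa_avail, num_tokens = 128, 64, 96
    #   >>> full_avail >= num_tokens and swa_avail >= num_tokens
    #   False
    # so the batch is treated as out of memory even though the full-attention pool
    # alone could fit it.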

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_ids: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None

    spec_info: Optional[SpecInput] = None

    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False