# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
28
29
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
30
31
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
Lianmin Zheng's avatar
Lianmin Zheng committed
32
33

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
34
"""

from __future__ import annotations

import copy
import dataclasses
import enum
import logging
import re
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.environ import envs
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import (
    alloc_for_decode,
    alloc_for_extend,
    alloc_token_slots,
)
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_flashinfer_cutlass_moe_fp4_allgather",
    "disable_radix_cache",
    "enable_dp_lm_head",
    "enable_fp32_lm_head",
    "flashinfer_mxfp4_moe_precision",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "pp_max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "speculative_attention_mode",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_multimodal",
    "enable_symm_mem",
    "enable_custom_logit_processor",
    "disaggregation_mode",
    "enable_deterministic_inference",
    "nsa_prefill",
    "nsa_decode",
    "multi_item_scoring_delimiter",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
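# Note: these entries start as the ServerArgs class defaults and are typically
# overwritten at startup with the live server arguments (an assumption about the
# surrounding scheduler code, not something enforced here).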

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISHED_MATCHED_REGEX(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]

@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 image inputs and 1 audio input, there will be
    2 MultimodalDataItems: one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
    # the precomputed embeddings, passed as final encoder embeddings
    # Exactly one of `feature` and `precomputed_embeddings` is set; the other stays empty.
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
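    # Unknown attribute reads fall back to model_specific_data (see __getattr__ /
    # __setitem__ below). Illustrative only, with a made-up key:
    #   item = MultimodalDataItem(modality=Modality.IMAGE)
    #   item.set("image_grid_thw", grid)   # stored in model_specific_data
    #   item.image_grid_thw                # resolved through __getattr__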

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )
    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
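        # The modulo keeps pad_value within a token-id-sized range; it is derived
        # from the content hash, so identical inputs get the same pad value (used
        # downstream when padding input_ids for this item; an interpretation, not
        # a contract of this class).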

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None
    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None
    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )
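        # (mrope_positions grows along the sequence dimension, dim=1, while the
        #  per-request mrope_position_delta entries are stacked along dim 0.)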

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class RequestStage(str, enum.Enum):
    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        disagg_mode: Optional[DisaggregationMode] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # for cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of the KV cache that has been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key

        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None
        self.mamba_pool_idx: Optional[torch.Tensor] = None  # shape (1)

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
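        # Worked example (illustrative numbers): with a 12-token prompt and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 12 and surr_offset = 7,
        # so only text decoded past position 12 is surfaced to the user.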
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None
        # The prefix length of the last prefix matching
        self.last_matched_prefix_len: int = 0

        # Whether the request is chunked. It increments whenever the request is
        # chunked and decrements whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
        self.has_log_time_stats: bool = False
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

        spec_alg = global_server_args_dict["speculative_algorithm"]
        return self.sampling_params.max_new_tokens == 0 and (
            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
        )
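        # e.g. a scoring/embedding-style request with max_new_tokens == 0 is
        # prefill-only; with speculative decoding enabled this fast path is
        # disabled, per the NOTE above.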

    def add_latency(self, stage: RequestStage):
        if self.metrics_collector is None:
            return

        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
        max_prefix_len = input_len - 1
        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
        max_prefix_len = max(max_prefix_len, 0)
        token_ids = self.fill_ids[:max_prefix_len]
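        # e.g. (illustrative) fill_ids of length 100 with return_logprob and
        # logprob_start_len = 60 gives max_prefix_len = min(99, 60) = 60, so at
        # most 60 tokens can be reused from the prefix cache.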

        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
                **(
                    {"req": self, "cow_mamba": True}
                    if isinstance(tree_cache, MambaRadixCache)
                    else {}
                ),
            )
            self.last_matched_prefix_len = len(self.prefix_indices)
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
            )
            self.cur_decode_ids_len = len(self.output_ids)
        else:
            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(self.output_ids)
        return self.surr_and_decode_ids, self.read_offset - self.surr_offset
    def tail_str(self) -> str:
        # Check stop strings and stop regex patterns together
        max_len_tail_str = 0
        if (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            max_len_tail_str = max(
                self.sampling_params.stop_str_max_len + 1,
                self.sampling_params.stop_regex_max_len + 1,
            )

        tail_len = min((max_len_tail_str + 1), len(self.output_ids))
        return self.tokenizer.decode(self.output_ids[-tail_len:])

    def check_match_stop_str_prefix(self) -> bool:
        """
        Check if the suffix of tail_str overlaps with any stop_str prefix
        """
        if not self.sampling_params.stop_strs:
            return False

        tail_str = self.tail_str()

        # Early return if tail_str is empty
        if not tail_str:
            return False

        for stop_str in self.sampling_params.stop_strs:
            if not stop_str:
                continue
            # Check if stop_str is contained in tail_str (fastest check first)
            if stop_str in tail_str:
                return True

            # Check if tail_str suffix matches stop_str prefix
            # Only check if stop_str is not empty, it's for stream output
            min_len = min(len(tail_str), len(stop_str))
            for i in range(1, min_len + 1):
                if tail_str[-i:] == stop_str[:i]:
                    return True

        return False

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]
        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return
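        # Guard against corrupted sampling output (e.g. NaN logits): an out-of-range
        # token id is rewritten to a known stop token and the request is finished.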
        if last_token_id > self.vocab_size or last_token_id < 0:
            if self.sampling_params.stop_token_ids:
                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
            if self.eos_token_ids:
                self.output_ids[-1] = next(iter(self.eos_token_ids))
            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
            return

        if (
            len(self.sampling_params.stop_strs) > 0
            or len(self.sampling_params.stop_regex_strs) > 0
        ):
            tail_str = self.tail_str()
            # Check stop strings
            if len(self.sampling_params.stop_strs) > 0:
                for stop_str in self.sampling_params.stop_strs:
                    if stop_str in tail_str or stop_str in self.decoded_text:
                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                        return

            # Check stop regex
            if len(self.sampling_params.stop_regex_strs) > 0:
                for stop_regex_str in self.sampling_params.stop_regex_strs:
                    if re.search(stop_regex_str, tail_str):
                        self.finished_reason = FINISHED_MATCHED_REGEX(
                            matched=stop_regex_str
                        )
                        return
    def reset_for_retract(self):
        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.mamba_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False
    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInput] = None

    # Whether to return hidden states
    return_hidden_states: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = -1

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert (
                tree_cache is None
                or isinstance(tree_cache, SWARadixCache)
                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            is_prefill_only=all(req.is_prefill_only for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
            mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
            if mamba_available_size < num_reqs:
                if self.tree_cache is not None and isinstance(
                    self.tree_cache, MambaRadixCache
                ):
                    mamba_num = max(0, num_reqs - mamba_available_size)
                    self.tree_cache.evict_mamba(mamba_num)
            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
        else:
            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def allocate_for_eagle_v2(self):
        from sglang.srt.speculative.eagle_info import EagleDraftInput
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

        bs = self.batch_size()

        assert self.spec_info.is_draft_input()
        draft_input: EagleDraftInput = self.spec_info

        # FIXME(lsyin): now implementation does not enable over-allocation
        # Now seq_lens and allocate_lens are correct
        self.maybe_wait_verify_done()

        new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
        num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
        out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)

        assign_req_to_token_pool[(bs,)](
            self.req_pool_indices,
            self.req_to_token_pool.req_to_token,
            draft_input.allocate_lens,
            new_allocate_lens,
            out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )
        draft_input.allocate_lens = new_allocate_lens

        # FIXME(lsyin): remove seq_lens_sum calculation
        self.seq_lens_cpu = self.seq_lens.cpu()
        self.seq_lens_sum = self.seq_lens_cpu.sum().item()

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len
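        # e.g. (hypothetical numbers) a request with encoder_len = 64,
        # extend_input_len = 80, and no prefix hit contributes its first 64 cache
        # slots to encoder_out_cache_loc and the remaining 16 to decoder_out_cache_loc.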

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"
    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
        ).to(self.device, non_blocking=True)
Lianmin Zheng's avatar
Lianmin Zheng committed
1228
1229
1230
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        # Set batch fields needed by alloc_for_extend
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.seq_lens = seq_lens_tensor
        self.seq_lens_cpu = seq_lens_cpu
        self.extend_num_tokens = extend_num_tokens

        # Allocate memory
        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
            self
        )

        # Set fields
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            #
            # Key variables:
            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
            # - extend_input_len: Number of tokens that need to be processed in this extend batch
            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
            #    and prefix_indices are the cached/shared prefix tokens)
            #
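            # Worked example (illustrative values): with a cached prefix of
            # pre_len=2, extend_input_len=4, and seqlen=6, a logprob_start_len of 3
            # maps to extend_logprob_start_len = min(3 - 2, 4, 6 - 1) = 1, i.e.
            # logprobs are computed from the second token of this extend chunk on.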
            if req.logprob_start_len >= pre_len:
                # Optimization for prefill-only requests: When we only need logprobs at
                # positions beyond the input sequence (to score next-token likelihood), skip all
                # input logprob computation during prefill since no generation will occur.
                if self.is_prefill_only and req.logprob_start_len == len(
                    req.origin_input_ids
                ):
                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                    req.extend_logprob_start_len = req.extend_input_len
                else:
                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
                    #
                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                    # This means: "compute logprobs from position 3 onwards in extend batch"
                    req.extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len,
                        req.extend_input_len,
                        req.seqlen - 1,
                    )
            else:
                # logprob_start_len is before the current extend batch, so start from beginning
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
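        """Mix this prefill batch with the running decode batch.

        Each running request contributes exactly one extend token, so the merged
        batch can be executed as a single MIXED forward pass.
        """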
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
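        """Count how many new KV cache pages the next decode step would allocate."""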
        page_size = self.token_to_kv_pool_allocator.page_size
        requests = (
            self.reqs
            if selected_indices is None
            else [self.reqs[i] for i in selected_indices]
        )
        if page_size == 1:
            return len(requests)

        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
        return (
            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(
        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
    ):
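        """Return True if the KV cache can hold one more decode step for this batch.

        Evicts from the tree cache first when the allocator is short on free slots.
        """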
        num_tokens = (
            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        self._evict_tree_cache_if_needed(num_tokens)
        return self._is_available_size_sufficient(num_tokens)

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        retracted_reqs = []
        first_iter = True
        while first_iter or (
            not self.check_decode_mem(selected_indices=sorted_indices)
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
            self.release_req(idx, len(sorted_indices), server_args)

        if len(retracted_reqs) == 0:
            # Corner case: only one request left
            raise ValueError(
                "Failed to retract any request. No space left for only one request."
            )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

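        # Estimate how much of the remaining max_new_tokens budget the kept requests
        # will use: tokens already decoded plus SGLANG_RETRACT_DECODE_STEPS more
        # decode steps per request, over the total max_new_tokens, capped at 1.0.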
        new_estimate_ratio = (
            total_decoded_tokens
            + envs.SGLANG_RETRACT_DECODE_STEPS.get() * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio, []

    def release_req(self, idx: int, remaining_req_count: int, server_args: ServerArgs):
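        """Free the req_to_token slot and uncached KV cache of one retracted request."""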
        req = self.reqs[idx]
        seq_lens_cpu = self.seq_lens_cpu.numpy()

        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        if isinstance(self.tree_cache, ChunkCache):
            # ChunkCache does not have eviction
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
        else:
            # TODO: apply more fine-grained retraction
            last_uncached_pos = (
                len(req.prefix_indices) // server_args.page_size
            ) * server_args.page_size
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)

            # release the last node
            if self.is_hybrid:
                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
            else:
                self.tree_cache.dec_lock_ref(req.last_node)

            # NOTE(lsyin): we should use the newly evictable memory instantly.
            num_tokens = remaining_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
            self._evict_tree_cache_if_needed(num_tokens)

        req.reset_for_retract()

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
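        """Prepare an empty batch for an IDLE forward pass (no requests to run)."""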
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    @property
    def is_v2_eagle(self):
        # FIXME: finally deprecate is_v2_eagle
        return self.enable_overlap and self.spec_algorithm.is_eagle()

    def prepare_for_decode(self):
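        """Prepare the batch for a decode forward pass (one new token per request)."""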
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.is_v2_eagle:
            # FIXME(lsyin): make this sync optional
            self.allocate_for_eagle_v2()

        if not self.spec_algorithm.is_none():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_decode()

        # Allocate memory
        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

        # Update seq_lens after allocation
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.seq_lens_cpu = self.seq_lens_cpu + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.seq_lens_cpu.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs

    def maybe_wait_verify_done(self):
        if self.is_v2_eagle:
            from sglang.srt.speculative.eagle_info import EagleDraftInput

            draft_input: EagleDraftInput = self.spec_info
            if draft_input.verify_done is not None:
                draft_input.verify_done.synchronize()

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
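        """Drop finished (or explicitly excluded) requests from the batch in place."""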
        # FIXME(lsyin): used here to get the correct seq_lens
        # The batch has been launched but we need it verified to get correct next batch info
        self.maybe_wait_verify_done()

        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
                has_been_filtered = False
            else:
                has_been_filtered = True
            self.spec_info.filter_batch(
                new_indices=keep_indices_device,
                has_been_filtered=has_been_filtered,
            )

    def merge_batch(self, other: "ScheduleBatch"):
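        """Merge another ScheduleBatch into this one by concatenating per-request state."""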
        # NOTE: in v2 eagle mode, we do not need wait verify here because
        # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
        # 2) other batch is always decode, which is finished in previous step

        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
        # needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
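        """Pack the fields needed by the model forward pass into a ModelWorkerBatch."""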
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
        )

        return ModelWorkerBatch(
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            is_prefill_only=self.is_prefill_only,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
            seq_lens_cpu=self.seq_lens_cpu,
            enable_overlap=self.enable_overlap,
        )

    def _evict_tree_cache_if_needed(self, num_tokens: int):
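        """Evict tree cache entries until `num_tokens` KV cache slots can be allocated.

        Chunk caches are skipped because they do not support eviction.
        """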
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
            return

        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()

            if full_available_size < num_tokens or swa_available_size < num_tokens:
                if self.tree_cache is not None:
                    full_num_tokens = max(0, num_tokens - full_available_size)
                    swa_num_tokens = max(0, num_tokens - swa_available_size)
                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
        else:
            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
                if self.tree_cache is not None:
                    self.tree_cache.evict(num_tokens)

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_ids: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None

    spec_info: Optional[SpecInput] = None

    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # Overlap scheduler related
    delay_sample_launch: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False