schedule_batch.py 74.3 KB
Newer Older
1
2
from __future__ import annotations

3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16
17
18
19
20
21
22
23
24
25
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we are considering removing it in the future.
"""
Lianmin Zheng's avatar
Lianmin Zheng committed
33

34
import copy
35
import dataclasses
Ying Sheng's avatar
Ying Sheng committed
36
import logging
37
import threading
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from enum import Enum, auto
39
from http import HTTPStatus
40
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
41

42
import numpy as np
Lianmin Zheng's avatar
Lianmin Zheng committed
43
import torch
44
45
import triton
import triton.language as tl
46

Liangsheng Yin's avatar
Liangsheng Yin committed
47
from sglang.global_config import global_config
48
from sglang.srt.configs.model_config import ModelConfig
49
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
50
from sglang.srt.disaggregation.base import BaseKVSender
Byron Hsu's avatar
Byron Hsu committed
51
52
53
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
54
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
Hanming Lu's avatar
Hanming Lu committed
55
56
57
58
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
59
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
tarinkk's avatar
tarinkk committed
60
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
61
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
Hanming Lu's avatar
Hanming Lu committed
62
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
63
from sglang.srt.metrics.collector import TimeStats
Lianmin Zheng's avatar
Lianmin Zheng committed
64
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
65
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
66
from sglang.srt.sampling.sampling_params import SamplingParams
67
from sglang.srt.server_args import ServerArgs
68
from sglang.srt.utils import flatten_nested_list, support_triton
Liangsheng Yin's avatar
Liangsheng Yin committed
69

70
if TYPE_CHECKING:
71
72
73
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

# How many prompt tokens before `read_offset` are included as surrounding
# context on the first incremental-detokenization step
# (used by `Req.init_incremental_detokenize`).
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Names of the `ServerArgs` attributes mirrored into `global_server_args_dict`.
GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_radix_cache",
    "enable_dp_attention",
    "enable_two_batch_overlap",
    "enable_dp_lm_head",
    "enable_deepep_moe",
    "deepep_mode",
    "enable_ep_moe",
    "enable_flashinfer_moe",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "deepep_config",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_triton_kernel_moe",
]

# Put some global args for easy access. The values are read from the
# `ServerArgs` class object itself, i.e. its class-level attribute values.
# NOTE(review): presumably overwritten elsewhere once the actual server args
# are known — confirm against the scheduler/worker setup code.
global_server_args_dict = {
    key: getattr(ServerArgs, key) for key in GLOBAL_SERVER_ARGS_KEYS
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    """Common base for the reasons a request stops generating.

    ``is_error`` flags reasons that represent a failure (the abort reason sets
    it); subclasses must implement :meth:`to_json`.
    """

    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        # Abstract hook: each concrete reason supplies its own JSON payload.
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    """Stop reason: a stop/EOS token id matched (also used on grammar termination)."""

    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        # Token id (or ids) that triggered the stop.
        self.matched = matched

    def to_json(self):
        # "stop" matches the finish-reason value of the OpenAI API.
        payload = {"type": "stop", "matched": self.matched}
        return payload


class FINISH_MATCHED_STR(BaseFinishReason):
    """Stop reason: one of the request's stop strings appeared in the output."""

    def __init__(self, matched: str):
        super().__init__()
        # The stop string that was found.
        self.matched = matched

    def to_json(self):
        # "stop" matches the finish-reason value of the OpenAI API.
        payload = {"type": "stop", "matched": self.matched}
        return payload


class FINISH_LENGTH(BaseFinishReason):
    """Stop reason: the output reached the ``max_new_tokens`` budget."""

    def __init__(self, length: int):
        super().__init__()
        # The token budget that was reached.
        self.length = length

    def to_json(self):
        # "length" matches the finish-reason value of the OpenAI API.
        payload = {"type": "length", "length": self.length}
        return payload


class FINISH_ABORT(BaseFinishReason):
    """Stop reason: the request was aborted. Always flagged as an error."""

    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        # Fall back to a generic message when none (or an empty one) is given.
        self.message = message if message else "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        payload = {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }
        return payload

class Modality(Enum):
    """Kinds of multimodal input a request can carry."""

    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        """Parse a modality name case-insensitively (e.g. ``"image"`` -> ``IMAGE``).

        Raises:
            ValueError: if the name does not match any member.
        """
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            # Suppress the internal KeyError context: the message below already
            # names the bad input and lists the valid choices.
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            ) from None

    @staticmethod
    def all():
        """Return IMAGE, VIDEO and AUDIO (``MULTI_IMAGES`` is not included)."""
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]

@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
    One for images and one for audio.

    We put the common fields first and the model-specific fields last.
    """

    modality: Modality
    hash: Optional[int] = None
    pad_value: Optional[int] = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Optional[Union[torch.Tensor, np.ndarray]] = None

    image_sizes: Optional[Tuple[int, int]] = None

    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None
    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    # For qwen-vl
    image_grid_thw: Optional[Union[torch.Tensor, np.ndarray]] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # For deepseek-vl
    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None

    # For minicpmv
    # [num_images, (n, w, h)]
    tgt_size: Optional[Tuple[int, int]] = None

    # For mllama
    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # For kimi-vl
    image_grid_hws: Optional[List[torch.Tensor]] = None

    # For gemma3n
    input_features_mask: Optional[torch.Tensor] = None

    # For phi4-mm
    image_attention_mask: Optional[torch.Tensor] = None
    audio_attention_mask: Optional[torch.Tensor] = None

    @staticmethod
    def is_empty_list(l):
        """Return True if ``l`` is None or holds no non-None values after flattening."""
        if l is None:
            return True
        return all(item is None for item in flatten_nested_list(l))

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data.

        The hash is computed from ``feature`` (or ``precomputed_features`` when
        ``feature`` is None) only if ``self.hash`` is not already set; the pad
        value is that hash reduced modulo ``2**30``.
        """
        if self.hash is None:
            # Imported lazily, presumably to avoid an import cycle; only needed
            # when a hash actually has to be computed.
            from sglang.srt.managers.mm_utils import hash_feature

            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_features
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def _has_data(self) -> bool:
        # True when there are precomputed features or any non-None raw feature.
        return self.precomputed_features is not None or not (
            MultimodalDataItem.is_empty_list(self.feature)
        )

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and self._has_data()

    def is_image(self):
        return (
            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
        ) and self._has_data()

    def is_video(self):
        return (self.modality == Modality.VIDEO) and self._has_data()

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        """Build an item from a dict of field values.

        A string ``modality`` is looked up by exact member name (e.g. ``"IMAGE"``).
        """
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        """Merge ``other`` into this item in place.

        ``feature``, ``image_sizes`` and ``offsets`` are combined with ``+=``
        (concatenation for lists), the two hashes are combined, and the pad
        value is recomputed from the combined hash.
        """
        self.feature += other.feature
        self.image_sizes += other.image_sizes
        # The dataclass field is ``offsets``; there is no ``image_offsets``
        # attribute, so referencing it made every merge raise AttributeError.
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()

@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        """Build from a dict holding ``mm_items`` plus optional token-id / mrope keys.

        Items whose ``is_valid()`` is False are dropped, and every remaining
        item gets its pad value computed.
        """
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(item.is_valid() for item in self.mm_items)

    def merge(self, other: MultimodalInputs):
        """
        Merge image inputs when requests are being merged.

        ``mm_items`` and ``image_pad_len`` are concatenated, mrope tensors are
        concatenated when both sides have them, and token-id fields are copied
        from ``other`` only where this object has none. Other fields are kept.
        """
        # List-valued args: concatenate only when both sides have a value, so
        # a missing value on ``other`` (e.g. image_pad_len=None) cannot raise
        # TypeError from ``list + None``.
        for arg in ("mm_items", "image_pad_len"):
            self_arg = getattr(self, arg, None)
            other_arg = getattr(other, arg, None)
            if self_arg is not None and other_arg is not None:
                setattr(self, arg, self_arg + other_arg)

        if self.mrope_positions is not None and other.mrope_positions is not None:
            self.mrope_positions = torch.cat(
                [self.mrope_positions, other.mrope_positions], dim=1
            )

        if (
            self.mrope_position_delta is not None
            and other.mrope_position_delta is not None
        ):
            self.mrope_position_delta = torch.cat(
                [self.mrope_position_delta, other.mrope_position_delta], dim=0
            )

        # set token_ids: fields whose name contains "_id" are filled from
        # ``other`` only when unset here.
        for key, value in vars(other).items():
            if "_id" in key and getattr(self, key, None) is None:
                setattr(self, key, value)
        # other args would be kept intact

Liangsheng Yin's avatar
Liangsheng Yin committed
425

Lianmin Zheng's avatar
Lianmin Zheng committed
426
class Req:
427
    """The input and output status of a request."""
428

429
430
431
432
    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: Optional[List[int]] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Union[List[int], Tuple[int, ...]]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: Optional[List[int]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
    ):
        """Initialize the per-request state.

        Only the constructor arguments carry caller data. Every other
        attribute starts as an empty/None/0 placeholder that later methods
        (e.g. ``init_next_round_input``, ``check_finished``) update.
        """
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of KV that have been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            # Copy first so injecting "__req__" does not mutate the caller's object.
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: Optional[str] = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix. Starts as an empty
        # list; replaced by the result of tree_cache.match_prefix in
        # init_next_round_input (presumably a tensor — confirm).
        self.prefix_indices: Union[torch.Tensor, List[int]] = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None

        # Whether or not if it is chunked. It increments whenever
        # it is chunked, and decrement whenever chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = None
            self.output_token_logprobs_idx = None
            self.output_top_logprobs_val = None
            self.output_top_logprobs_idx = None
            self.output_token_ids_logprobs_val = None
            self.output_token_ids_logprobs_idx = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: Optional[str] = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

631
632
633
634
    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

635
    def extend_image_inputs(self, image_inputs):
Mick's avatar
Mick committed
636
637
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
638
        else:
Mick's avatar
Mick committed
639
            self.multimodal_inputs.merge(image_inputs)
640

641
    def finished(self) -> bool:
Lianmin Zheng's avatar
Lianmin Zheng committed
642
        # Whether request reached finished condition
643
644
        return self.finished_reason is not None

645
646
647
648
    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
    ):
649
        self.fill_ids = self.origin_input_ids + self.output_ids
650
        if tree_cache is not None:
651
652
653
654
655
656
657
658
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=self.adjust_max_prefix_ids(),
            )
659
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
660

661
    def adjust_max_prefix_ids(self):
662
663
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)
664
665
666
667

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1
Liangsheng Yin's avatar
Liangsheng Yin committed
668
669
670
671
672

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

673
        if self.return_logprob:
674
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
675

676
        max_prefix_len = max(max_prefix_len, 0)
677
        return self.fill_ids[:max_prefix_len]
678

Liangsheng Yin's avatar
Liangsheng Yin committed
679
    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
680
    def init_incremental_detokenize(self):
Liangsheng Yin's avatar
Liangsheng Yin committed
681
682
683
684
685
686
687
688
689
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
690
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
Liangsheng Yin's avatar
Liangsheng Yin committed
691

692
    def check_finished(self):
        """Decide whether this request has finished and record the reason.

        Does nothing when the request already finished. Otherwise the first
        matching condition below sets ``self.finished_reason``:

        1. an abort was requested (``to_abort``);
        2. at least ``max_new_tokens`` output tokens were produced;
        3. the grammar reports termination;
        4. unless ``ignore_eos`` is set, the last token is in
           ``stop_token_ids``, in ``eos_token_ids``, equals the tokenizer's
           ``eos_token_id`` or is in its ``additional_stop_token_ids``;
        5. a stop string occurs in the decoded output tail or in
           ``decoded_text``.
        """
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(message=self.to_abort_message)
            return

        params = self.sampling_params
        if len(self.output_ids) >= params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(length=params.max_new_tokens)
            return

        grammar = self.grammar
        if grammar is not None and grammar.is_terminated():
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
            return

        last_token_id = self.output_ids[-1]

        if not params.ignore_eos:
            # Every applicable source is consulted (no short-circuiting) and the
            # results are OR-ed together.
            is_stop_token = False
            if params.stop_token_ids:
                is_stop_token = last_token_id in params.stop_token_ids
            if self.eos_token_ids:
                is_stop_token |= last_token_id in self.eos_token_ids
            tokenizer = self.tokenizer
            if tokenizer is not None:
                is_stop_token |= last_token_id == tokenizer.eos_token_id
                extra_stop_ids = tokenizer.additional_stop_token_ids
                if extra_stop_ids:
                    is_stop_token |= last_token_id in extra_stop_ids
            if is_stop_token:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings against the last (stop_str_max_len + 1) tokens.
        stop_strs = params.stop_strs
        if len(stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(params.stop_str_max_len + 1) :]
            )
            matched_str = next(
                (
                    candidate
                    for candidate in stop_strs
                    if candidate in tail_str or candidate in self.decoded_text
                ),
                None,
            )
            if matched_str is not None:
                self.finished_reason = FINISH_MATCHED_STR(matched=matched_str)

744
745
746
    def reset_for_retract(self):
        """Reset per-run scheduling state after this request is retracted.

        Clears the matched prefix and tree-cache handles, zeroes the extend and
        input-logprob bookkeeping, forgets ``req_pool_idx`` and marks the
        request as retracted.
        """
        # Prefix / tree-cache matching state.
        self.prefix_indices, self.last_node, self.swa_uuid_for_lock = [], None, None
        # Extend bookkeeping.
        self.extend_input_len, self.is_retracted = 0, True
        # Input-logprob buffers.
        (
            self.input_token_logprobs,
            self.temp_input_top_logprobs_val,
            self.temp_input_top_logprobs_idx,
        ) = (None, None, None)
        self.extend_logprob_start_len = 0
        # Chunking / pool bookkeeping.
        self.is_chunked, self.req_pool_idx, self.already_computed = 0, None, 0
757

Lianmin Zheng's avatar
Lianmin Zheng committed
758
759
760
761
762
763
764
765
766
767
768
769
770
    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        """Snapshot this request's KV cache into ``self.kv_cache_cpu``.

        The snapshot covers the first ``seqlen - 1`` slot indices stored in this
        request's row of ``req_to_token_pool.req_to_token`` and is produced by
        ``token_to_kv_pool_allocator.get_cpu_copy``.
        """
        row = self.req_pool_idx
        num_cached = self.seqlen - 1
        kv_slots = req_to_token_pool.req_to_token[row, :num_cached]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(kv_slots)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        """Restore the snapshot in ``self.kv_cache_cpu`` and then drop it.

        The snapshot is written back through
        ``token_to_kv_pool_allocator.load_cpu_copy`` into the first
        ``seqlen - 1`` slot indices of this request's ``req_to_token`` row.
        """
        row = self.req_pool_idx
        num_cached = self.seqlen - 1
        kv_slots = req_to_token_pool.req_to_token[row, :num_cached]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, kv_slots)
        # Remove the attribute entirely so the snapshot can be freed.
        del self.kv_cache_cpu

771
772
773
774
775
776
777
778
779
780
781
782
    def log_time_stats(self):
        """Log this request's time statistics once.

        Later calls are no-ops once ``has_log_time_stats`` is True. With the
        overlap scheduler this can be invoked twice, because a decode batch is
        scheduled one step ahead.
        """
        if self.has_log_time_stats is True:
            return

        room_part = (
            f", bootstrap_room={self.bootstrap_room}"
            if self.bootstrap_room is not None
            else ""
        )
        prefix = (
            f"Req Time Stats(rid={self.rid}{room_part}, "
            f"input len={len(self.origin_input_ids)}, "
            f"output len={len(self.output_ids)}, "
            f"type={self.time_stats.get_type().value})"
        )
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

783
784
785
786
787
788
    def set_finish_with_abort(self, error_msg: str):
        """Finish this request as a bad request carrying ``error_msg``.

        The error is logged only on tensor-parallel rank 0. Multimodal inputs
        and the grammar are dropped, the prompt is replaced by a single token,
        logprob return is disabled, and ``finished_reason`` becomes a
        ``FINISH_ABORT`` with ``HTTPStatus.BAD_REQUEST``.
        """
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")

        self.multimodal_inputs = self.grammar = None
        # One placeholder token so the (possibly long) prefill is skipped.
        self.origin_input_ids = [0]
        self.return_logprob = False

        reason = FINISH_ABORT(error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError")
        self.finished_reason = reason

Lianmin Zheng's avatar
Lianmin Zheng committed
794
    def __repr__(self):
795
        return (
796
            f"Req(rid={self.rid}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
797
798
799
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
800
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
801
802


Lianmin Zheng's avatar
Lianmin Zheng committed
803
# Module-level batch id, initialized to 0.
bid = 0


807
@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler.

    Most fields default to ``None`` and are filled in by ``init_new`` and the
    ``prepare_for_*`` methods.
    """

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    # True when the allocator is a SWATokenToKVPoolAllocator (set by init_new).
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    # int64 token ids; flattened to [extend_num_tokens] by prepare_for_extend,
    # [b] for decode.
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache: one int64 slot per input token
    # ([extend_num_tokens] in extend, [b] in decode).
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # None when no request in the batch wants logprobs; otherwise a 1-D tensor
    # of the next-token ids used to compute input logprobs (0-padded).
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = 0

907
    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        """Create a batch for ``reqs``.

        The batch-level ``return_logprob``, ``has_stream``, ``has_grammar`` and
        ``return_hidden_states`` flags are set when any request needs them, and
        ``device`` is taken from ``req_to_token_pool``. A
        ``SWATokenToKVPoolAllocator`` marks the batch as hybrid and requires an
        ``SWARadixCache`` or ``SWAChunkCache`` tree cache.
        """
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
        if is_hybrid:
            assert isinstance(
                tree_cache, (SWARadixCache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

947
    def batch_size(self):
        """Number of requests in this batch."""
        num_reqs = len(self.reqs)
        return num_reqs
949

Lianmin Zheng's avatar
Lianmin Zheng committed
950
951
952
    def is_empty(self):
        """True when the batch holds no requests."""
        return not len(self.reqs)

953
    def alloc_req_slots(self, num_reqs: int):
        """Reserve ``num_reqs`` rows in ``req_to_token_pool`` and return their indices.

        Raises:
            RuntimeError: if the pool's ``alloc`` returns ``None``.
        """
        slots = self.req_to_token_pool.alloc(num_reqs)
        if slots is not None:
            return slots
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{self.req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )

964
    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        """Allocate ``num_tokens`` KV-cache slots, evicting from the tree cache first if needed.

        Args:
            num_tokens: number of slots to allocate.
            backup_state: when True, ``backup_state()`` of the allocator is
                captured right before allocating and returned as well.

        Returns:
            ``out_cache_loc``, or ``(out_cache_loc, state)`` when ``backup_state``.

        Raises:
            RuntimeError: if the allocator returns ``None``; the error is logged
                and the tree cache (if any) is pretty-printed first.
        """
        self._evict_tree_cache_if_needed(num_tokens)

        allocator = self.token_to_kv_pool_allocator
        saved_state = allocator.backup_state() if backup_state else None
        out_cache_loc = allocator.alloc(num_tokens)

        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        return (out_cache_loc, saved_state) if backup_state else out_cache_loc
Lianmin Zheng's avatar
Lianmin Zheng committed
987
988
989
990
991
992
993

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        """Allocate paged KV-cache slots for an extend step.

        Before allocating, the tree cache is evicted for ``extend_num_tokens``
        plus one page per request (``len(seq_lens) * page_size``) of headroom.

        Returns:
            ``out_cache_loc``, or ``(out_cache_loc, state)`` when
            ``backup_state`` (``state`` is captured before allocating).

        Raises:
            RuntimeError: if ``alloc_extend`` returns ``None``.
        """
        page_headroom = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        self._evict_tree_cache_if_needed(extend_num_tokens + page_headroom)

        allocator = self.token_to_kv_pool_allocator
        saved_state = allocator.backup_state() if backup_state else None
        out_cache_loc = allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )

        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        return (out_cache_loc, saved_state) if backup_state else out_cache_loc
Lianmin Zheng's avatar
Lianmin Zheng committed
1021
1022
1023
1024
1025

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        """Allocate paged KV-cache slots for one decode step.

        The tree cache is first evicted for one page per request
        (``len(seq_lens) * page_size`` tokens).

        Returns:
            ``out_cache_loc``, or ``(out_cache_loc, state)`` when
            ``backup_state`` (``state`` is captured before allocating).

        Raises:
            RuntimeError: if ``alloc_decode`` returns ``None``.
        """
        self._evict_tree_cache_if_needed(
            len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        )

        allocator = self.token_to_kv_pool_allocator
        saved_state = allocator.backup_state() if backup_state else None
        out_cache_loc = allocator.alloc_decode(seq_lens, last_loc)

        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        return (out_cache_loc, saved_state) if backup_state else out_cache_loc
1049

1050
1051
1052
1053
1054
    def prepare_encoder_info_extend(
        self, input_ids: List[List[int]], seq_lens: List[int]
    ):
        """Separate encoder (image) tokens from an extend batch of an encoder-decoder model.

        For each request the first ``num_image_tokens`` positions form the
        encoder part (0 when there is no multimodal input). If the encoder part
        is not yet covered by the request's prefix, its tokens are stripped from
        ``input_ids``, ``extend_lens`` and ``extend_num_tokens``, and its cache
        slots move from ``out_cache_loc`` to ``encoder_out_cache_loc``;
        otherwise ``prefix_lens`` is reduced instead. ``seq_lens`` always has
        the encoder length subtracted.

        Args:
            input_ids: per-request extend token ids; modified in place.
            seq_lens: per-request sequence lengths; modified in place.

        Sets ``encoder_lens_cpu``, ``encoder_cached``, ``encoder_lens``,
        ``input_ids``, ``seq_lens``, ``out_cache_loc`` and
        ``encoder_out_cache_loc`` on the batch.
        """
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0  # start of the current request's slots in out_cache_loc
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        # Flatten with a comprehension: sum(input_ids, []) re-copies the growing
        # list once per request, which is quadratic in the batch size.
        flat_input_ids = [tok for ids in input_ids for tok in ids]
        self.input_ids = torch.tensor(flat_input_ids, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"
1122

1123
    def prepare_for_extend(self):
        """Prepare this batch for an EXTEND (prefill) forward pass.

        Steps:
        1. allocate one ``req_to_token_pool`` row per request;
        2. build the flattened ``input_ids`` tensor and per-request length tensors;
        3. write each request's cached prefix into its row, update cache
           bookkeeping, and compute the extend-relative logprob start (plus the
           next-token ids used for input logprobs when logprobs are requested);
        4. allocate KV-cache slots for the new tokens (paged or unpaged);
        5. set the batch fields and record the new slots in ``req_to_token_pool``;
        6. strip encoder tokens for encoder-decoder models and build
           ``sampling_info``.
        """
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        # Flatten with a comprehension; sum(list_of_lists, []) is quadratic in
        # the number of requests.
        input_ids_tensor = torch.tensor(
            [tok for ids in input_ids for tok in ids], dtype=torch.int64
        ).to(self.device, non_blocking=True)
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                [tok for ids in token_type_ids for tok in ids], dtype=torch.int64
            ).to(self.device, non_blocking=True)

        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
                if isinstance(self.tree_cache, SWAChunkCache):
                    self.tree_cache.evict_swa(
                        req, pre_len, self.model_config.attention_chunk_size
                    )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            # Pin int64: an empty list would otherwise become a float32 tensor,
            # which cannot be used as token ids.
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids, dtype=torch.int64
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )
1321

1322
1323
1324
1325
1326
    def prepare_for_split_prefill(self):
        """Prepare the batch exactly as for extend, then tag it as SPLIT_PREFILL."""
        self.prepare_for_extend()
        # prepare_for_extend sets EXTEND; override it for split prefill.
        self.forward_mode = ForwardMode.SPLIT_PREFILL

1327
    def mix_with_running(self, running_batch: "ScheduleBatch"):
        """Fold a running decode batch into this extend batch (MIXED mode).

        Every running request contributes one new token: its ``fill_ids``
        becomes prompt + outputs so far and its ``extend_input_len`` is 1. Token
        ids and cache locations are concatenated (this batch first), and the
        per-request extend bookkeeping is extended for the running requests.
        """
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        # Concatenate before merging; both fields are overwritten right after
        # merge_batch with these values.
        merged_input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        merged_out_cache_loc = torch.cat(
            [self.out_cache_loc, running_batch.out_cache_loc]
        )

        self.merge_batch(running_batch)
        self.input_ids = merged_input_ids
        self.out_cache_loc = merged_out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = -1 if not self.enable_overlap else 0

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        running_prefix_lens = [
            len(r.origin_input_ids) + len(r.output_ids) + delta
            for r in running_batch.reqs
        ]
        self.prefix_lens.extend(running_prefix_lens)
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)
1356

1357
1358
1359
1360
    def new_page_count_next_decode(self):
        """Count the requests that start a new KV page on the next decode step.

        With ``page_size == 1`` every request does. Otherwise a request counts
        when its KV length is a multiple of ``page_size``. That KV length is
        ``seqlen`` with the overlap scheduler and ``seqlen - 1`` without it,
        because in decode a request's KV cache is its total length minus 1.
        """
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)

        shift = 0 if self.enable_overlap else 1
        return sum((req.seqlen - shift) % page_size == 0 for req in self.reqs)
1368

1369
    def check_decode_mem(self, buf_multiplier=1):
        """Return whether the KV pool can hold the next decode step.

        The token requirement is (pages needed next step) x ``buf_multiplier`` x
        page size. The tree cache is asked to evict if the pool is short, and
        then the available size is compared against the requirement.
        """
        pages_needed = self.new_page_count_next_decode() * buf_multiplier
        tokens_needed = pages_needed * self.token_to_kv_pool_allocator.page_size
        self._evict_tree_cache_if_needed(tokens_needed)
        return self._is_available_size_sufficient(tokens_needed)
1378

1379
    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory.

        Request indices are popped from the back of a list, and each popped
        request's KV memory is released. This continues until the pool has room
        for the remaining requests to decode ``global_config.retract_decode_steps``
        more steps, plus speculative-decoding headroom when enabled.

        - At least one request is retracted per call, unless the batch holds a
          single request.
        - Without speculative decoding, the list is first sorted so that requests
          with the fewest output tokens (ties: longest prompt) are retracted first.

        Args:
            server_args: server configuration. Reads the speculative-decoding
                settings, ``disaggregation_mode`` and ``page_size``.

        Returns:
            ``(retracted_reqs, new_estimate_ratio)``. ``new_estimate_ratio`` is
            (decoded tokens + reserved steps) / total ``max_new_tokens`` over the
            requests that remain in the batch, capped at 1.0.
        """
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            # Descending order: the back of the list (popped first) holds the
            # requests with the fewest output tokens, ties broken by longest prompt.
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            # Free tokens needed for `num_reqs` requests to keep decoding, plus
            # room for draft tokens when speculative decoding is enabled.
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        def _get_available_size():
            # In hybrid (SWA) mode the tighter of the two pools is the limit.
            if self.is_hybrid:
                return min(
                    self.token_to_kv_pool_allocator.full_available_size(),
                    self.token_to_kv_pool_allocator.swa_available_size(),
                )
            else:
                return self.token_to_kv_pool_allocator.available_size()

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        # `first_iter` forces at least one retraction attempt even when memory
        # already looks sufficient.
        while (
            _get_available_size() < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if server_args.disaggregation_mode == "decode":
                req.offload_kv_cache(
                    self.req_to_token_pool, self.token_to_kv_pool_allocator
                )

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                # Only tokens from the page-aligned end of the cached prefix
                # onward are freed here.
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                if self.is_hybrid:
                    self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
                else:
                    self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                num_tokens = len(sorted_indices) * global_config.retract_decode_steps
                self._evict_tree_cache_if_needed(num_tokens)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio
1502

1503
1504
1505
1506
    def prepare_encoder_info_decode(self):
        """Mark every request's encoder output as already cached."""
        self.encoder_cached = [True for _ in self.reqs]

Ke Bao's avatar
Ke Bao committed
1507
1508
    def prepare_for_idle(self):
        """Reset this batch to an empty IDLE batch with zero-length tensors."""
        self.forward_mode = ForwardMode.IDLE
        device = self.device

        def _empty(dtype):
            return torch.empty(0, dtype=dtype, device=device)

        self.input_ids = _empty(torch.int64)
        self.seq_lens = _empty(torch.int64)
        self.out_cache_loc = _empty(torch.int64)
        self.req_pool_indices = _empty(torch.int32)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self, self.model_config.vocab_size
        )
Ke Bao's avatar
Ke Bao committed
1519

1520
    def prepare_for_decode(self):
        """Turn this batch into a DECODE batch for the next one-token step.

        Steps:
        - Feeds the penalizer the latest tokens.
        - Moves ``output_ids`` into ``input_ids`` and advances ``seq_lens`` by one.
        - Trims SWA chunk-cache memory when applicable.
        - Allocates one KV slot per request and records it in
          ``req_to_token_pool``.

        For EAGLE speculative decoding only the forward mode is set here.
        """
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        num_reqs = len(self.reqs)

        orchestrator = self.sampling_info.penalizer_orchestrator
        if orchestrator.is_required:
            if not self.enable_overlap:
                latest_tokens = self.output_ids.to(torch.int64)
            else:
                # Overlap mode: take each request's last output token, or the
                # last prompt token when nothing has been generated yet.
                # TODO: this can be slow, optimize this.
                latest_tokens = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
            orchestrator.cumulate_output_tokens(latest_tokens)

        # Update fields: the sampled tokens become the next step's inputs.
        self.input_ids = self.output_ids
        self.output_ids = None

        # Positions where the new tokens' KV slots are recorded; computed from
        # seq_lens before it is incremented below.
        if not self.model_config.is_encoder_decoder:
            locs = self.seq_lens.clone()
        else:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += num_reqs

        # free memory
        if isinstance(self.tree_cache, SWAChunkCache):
            for req in self.reqs:
                self.tree_cache.evict_swa(
                    req, req.seqlen - 1, self.model_config.attention_chunk_size
                )

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(num_reqs)
        else:
            # seq_lens was already incremented, so `seq_lens - 2` is the
            # position just before the one being written.
            prev_positions = self.seq_lens - 2
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, prev_positions
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

1592
1593
    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        """Keep only a subset of this batch's requests, updating fields in place.

        Args:
            chunked_req_to_exclude: a request, or list of requests, to drop in
                addition to finished ones. Only consulted when ``keep_indices``
                is None.
            keep_indices: indices into ``self.reqs`` of the requests to keep.
                When None, every unfinished request that is not in
                ``chunked_req_to_exclude`` is kept.

        If nothing is kept, only ``self.reqs`` is cleared. If everything is kept,
        the batch is left untouched.
        """
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i, req in enumerate(self.reqs)
                if not req.finished() and req not in chunked_req_to_exclude
            ]

        if len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        # output_ids can be None (prepare_for_decode sets it to None); guard it
        # the same way merge_batch does instead of indexing None.
        if self.output_ids is not None:
            self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)
Lianmin Zheng's avatar
Lianmin Zheng committed
1648

1649
    def merge_batch(self, other: "ScheduleBatch"):
        """Append ``other``'s requests and per-request state to this batch in place."""
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_sum += other.seq_lens_sum
        self.out_cache_loc = None
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])

        # Pad logprob settings for whichever side did not request logprobs.
        # This uses the pre-merge len(self.reqs), so it must run before
        # self.reqs is extended.
        if self.return_logprob:
            if other.return_logprob:
                self.top_logprobs_nums.extend(other.top_logprobs_nums)
                self.token_ids_logprobs.extend(other.token_ids_logprobs)
            else:
                num_other = len(other.reqs)
                self.top_logprobs_nums.extend([0] * num_other)
                self.token_ids_logprobs.extend([None] * num_other)
        elif other.return_logprob:
            num_self = len(self.reqs)
            self.top_logprobs_nums = [0] * num_self + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * num_self + other.token_ids_logprobs

        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

1688
1689
1690
    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        """Snapshot the model-forward fields of this batch as a ModelWorkerBatch.

        Args:
            seq_lens_cpu_cache: an already-available CPU copy of ``seq_lens``.
                When a CPU copy is required, it is used instead of calling
                ``self.seq_lens.cpu()``.

        Returns:
            A new ``ModelWorkerBatch`` stamped with the next global batch id.
            - Extend-only fields are None for decode/idle batches.
            - ``seq_lens_cpu`` is None unless the attention backend or two-batch
              overlap needs it.
        """
        is_decode_like = self.forward_mode.is_decode_or_idle()
        extend_seq_lens = None if is_decode_like else self.extend_lens
        extend_prefix_lens = None if is_decode_like else self.prefix_lens
        extend_logprob_start_lens = (
            None if is_decode_like else self.extend_logprob_start_lens
        )

        # Create seq_lens_cpu when needed
        attn_backend = global_server_args_dict["attention_backend"]
        wants_seq_lens_cpu = (
            attn_backend == "fa3"
            or (
                global_server_args_dict["use_mla_backend"]
                and attn_backend == "flashinfer"
            )
            or attn_backend in ("flashmla", "cutlass_mla", "ascend")
            or global_server_args_dict["enable_two_batch_overlap"]
        )
        if not wants_seq_lens_cpu:
            seq_lens_cpu = None
        elif seq_lens_cpu_cache is not None:
            seq_lens_cpu = seq_lens_cpu_cache
        else:
            seq_lens_cpu = self.seq_lens.cpu()

        if self.sampling_info:
            # Grammars are attached only when some request in the batch has one.
            self.sampling_info.grammars = (
                [req.grammar for req in self.reqs] if self.has_grammar else None
            )

        if self.return_hidden_states:
            capture_hidden_mode = CaptureHiddenMode.FULL
        elif self.spec_info:
            capture_hidden_mode = getattr(
                self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
            )
        else:
            capture_hidden_mode = CaptureHiddenMode.NULL

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=capture_hidden_mode,
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

1775
    def copy(self):
        """Return a lightweight ScheduleBatch that shares this batch's fields.

        Only contain fields that will be used by process_batch_result. The
        values are passed by reference, not deep-copied.
        """
        kept_fields = {
            "reqs": self.reqs,
            "model_config": self.model_config,
            "forward_mode": self.forward_mode,
            "out_cache_loc": self.out_cache_loc,
            "return_logprob": self.return_logprob,
            "decoding_reqs": self.decoding_reqs,
            "spec_algorithm": self.spec_algorithm,
            "enable_custom_logit_processor": self.enable_custom_logit_processor,
            "global_num_tokens": self.global_num_tokens,
            "global_num_tokens_for_logprob": self.global_num_tokens_for_logprob,
            "can_run_dp_cuda_graph": self.can_run_dp_cuda_graph,
            "is_extend_in_batch": self.is_extend_in_batch,
        }
        return ScheduleBatch(**kept_fields)

Hanming Lu's avatar
Hanming Lu committed
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
    def _evict_tree_cache_if_needed(
        self,
        num_tokens: int,
    ) -> None:
        """Ask the tree cache to evict when fewer than ``num_tokens`` slots are free.

        - Does nothing for ``SWAChunkCache`` or when ``tree_cache`` is None.
        - In hybrid mode, the shortfall of the full pool and of the SWA pool is
          passed separately (clamped at 0).
        - Otherwise ``num_tokens`` itself is requested.
        """
        if isinstance(self.tree_cache, SWAChunkCache):
            return

        allocator = self.token_to_kv_pool_allocator
        if not self.is_hybrid:
            if allocator.available_size() < num_tokens and self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)
            return

        full_free = allocator.full_available_size()
        swa_free = allocator.swa_available_size()
        short_of_space = full_free < num_tokens or swa_free < num_tokens
        if short_of_space and self.tree_cache is not None:
            self.tree_cache.evict(
                max(0, num_tokens - full_free), max(0, num_tokens - swa_free)
            )

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens

    def _available_and_evictable_str(self) -> str:
        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
            full_evictable_size = self.tree_cache.full_evictable_size()
            swa_evictable_size = self.tree_cache.swa_evictable_size()
            return (
                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
            )
        else:
            available_size = self.token_to_kv_pool_allocator.available_size()
            evictable_size = self.tree_cache.evictable_size()
            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

1839
1840
    def __str__(self):
        return (
1841
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
1842
1843
1844
            f"#req={(len(self.reqs))})"
        )

Chayenne's avatar
Chayenne committed
1845

1846
@dataclasses.dataclass
class ModelWorkerBatch:
    """The subset of ScheduleBatch that only contains data for the model forward on GPU.

    Built by ``ScheduleBatch.get_model_worker_batch``.
    """

    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU (may be None when not needed)
    seq_lens_cpu: Optional[torch.Tensor]
    # Sum of seq_lens
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend (the list fields are None for decode/idle batches)
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder model
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: Optional[SpeculativeAlgorithm] = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: Optional[CaptureHiddenMode] = None
    spec_num_draft_tokens: Optional[int] = None
    hicache_consumer_index: int = 0

    # Overlap event
    launch_done: Optional[threading.Event] = None

1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933

@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    # One program per request: copy its newly allocated slots, which start in
    # out_cache_loc at sum(extend_lens[:pid]), into columns [pre_len, seq_len)
    # of row req_pool_indices[pid] of req_to_token.
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)
    num_new = seq_len - pre_len

    # NOTE: This can be slow for large bs
    src_start = tl.cast(0, tl.int64)
    for j in range(pid):
        src_start += tl.load(extend_lens + j)

    for blk in range(tl.cdiv(num_new, BLOCK_SIZE)):
        offs = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        in_range = offs < num_new
        vals = tl.load(out_cache_loc + src_start + offs, mask=in_range)
        dst = req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offs + pre_len
        tl.store(dst, vals, mask=in_range)
Lianmin Zheng's avatar
Lianmin Zheng committed
1952
1953


1954
1955
1956
1957
1958
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Return the slot of each request's last prefix token, or -1 if the prefix is empty.

    The "ascend" and "torch_native" attention backends use the pure-torch path;
    every other backend uses the Triton kernel.
    """
    if global_server_args_dict["attention_backend"] in ("ascend", "torch_native"):
        return get_last_loc_torch(
            req_to_token, req_pool_indices_tensor, prefix_lens_tensor
        )
    return get_last_loc_triton(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Pure-torch get_last_loc: ``req_to_token[row, prefix_len - 1]``, or -1 when prefix_len is 0."""
    gathered = req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1]
    no_prefix_fill = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(prefix_lens_tensor > 0, gathered, no_prefix_fill)
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles BLOCK_SIZE consecutive requests; lanes at or past
    # num_tokens are masked out.
    pid = tl.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < num_tokens

    prefix_len = tl.load(prefix_lens_tensor + idx, mask=in_bounds, other=0)
    pool_idx = tl.load(req_pool_indices_tensor + idx, mask=in_bounds, other=0)

    # Requests without a prefix are not loaded and get -1.
    has_prefix = prefix_len > 0
    flat_index = pool_idx * req_to_token_stride + (prefix_len - 1)
    last_loc = tl.load(req_to_token + flat_index, mask=has_prefix, other=-1)

    tl.store(result + idx, last_loc, mask=in_bounds)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Triton get_last_loc; the result has the shape and dtype of ``prefix_lens_tensor``."""
    BLOCK_SIZE = 256
    n = prefix_lens_tensor.shape[0]
    out = torch.empty_like(prefix_lens_tensor)
    launch_grid = (triton.cdiv(n, BLOCK_SIZE),)
    get_last_loc_kernel[launch_grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        out,
        n,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return out