"vscode:/vscode.git/clone" did not exist on "8819288a52136b034bfe3b3501335c2d63b1418c"
schedule_batch.py 71.1 KB
Newer Older
1
2
from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_radix_cache",
    "enable_dp_attention",
    "enable_two_batch_overlap",
    "enable_dp_lm_head",
    "enable_deepep_moe",
    "deepep_mode",
    "enable_ep_moe",
    "enable_flashinfer_moe",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "deepep_config",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_acc",
    "speculative_accept_threshold_single",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
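# Example access (illustrative): `global_server_args_dict["attention_backend"]` holds the
# ServerArgs class default captured above; it may later be overwritten with live server settings.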

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, from a single image/video/audio or other source
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    # kimi-vl related
    image_grid_hws: Optional[List[torch.Tensor]] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret
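    # Illustrative usage (placeholder values): `from_dict` accepts the modality either as a
    # `Modality` member or as its string name, e.g.
    #   item = MultimodalDataItem.from_dict({"modality": "IMAGE", "pixel_values": pixels})
    # where `pixels` stands for a caller-provided tensor/array.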


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """ """
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """ """
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0

        # Whether it is chunked. It increments whenever the request is
        # chunked, and decrements whenever the chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=self.adjust_max_prefix_ids(),
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
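        # Worked example (illustrative numbers): with 10 unpadded prompt tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call sets read_offset = 10
        # and surr_offset = 5, so this returns (all_ids[5:], 5); the surrounding window lets
        # the detokenizer re-decode a few tokens for stable incremental output.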

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
    is_extend_in_batch: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = 0

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len tokens;
                # logprob_token_ids only covers the input logprobs, so pad the rest with 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )
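        # For the paged path, `last_loc` holds the KV slot of each request's final cached
        # prefix token (-1 when there is no prefix); the allocator presumably uses it to
        # keep filling the partially occupied last page before opening new pages.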

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1
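        # Each running (decode) request is folded in as an extend of length 1: its
        # fill_ids cover the full sequence, but only its newest token is part of the
        # extend input.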

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
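        # Hypothetical example with page_size=4: seqlen=9 means 8 cached tokens and
        # 8 % 4 == 0, so the next decoded token starts a fresh page for that request.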
        return (
            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
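        # `tokens_required` is counted in whole pages: one page per request that will
        # need a fresh page at the next decode step, scaled by `buf_multiplier`.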

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
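        # Hypothetical example: if retract_decode_steps were 20 and 3 requests remain,
        # at least 60 free token slots are required before retraction can stop (plus the
        # speculative-decoding headroom when that mode is enabled).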

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if server_args.disaggregation_mode == "decode":
                req.offload_kv_cache(
                    self.req_to_token_pool, self.token_to_kv_pool_allocator
                )

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        if len(retracted_reqs) == 0:
            # Corner case: only one request left
            raise ValueError(
                "Failed to retract any request. No space left for only one request."
            )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
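        # The ratio assumes every remaining request decodes for at least
        # `retract_decode_steps` more steps; it is capped at 1.0 and returned so the
        # caller can presumably use it as a more conservative new-token estimate.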

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )
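        # `self.seq_lens` was advanced above, so the last already-cached token of each
        # request sits at index `seq_lens - 2` (used as `last_loc`), while the token
        # decoded this step is written at `locs`, the pre-increment sequence length.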

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during the preparation of each penalizer,
        # so it needs to be called with the pre-merge Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            global_server_args_dict["attention_backend"] == "fa3"
            or (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
                seq_lens_cpu_cache
                if seq_lens_cpu_cache is not None
                else self.seq_lens.cpu()
            )
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder model
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    spec_num_draft_tokens: Optional[int] = None
    hicache_consumer_index: int = 0

    # Overlap event
    launch_done: Optional[threading.Event] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
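    # Each program instance handles one request: it copies that request's newly
    # allocated KV slots (a contiguous slice of out_cache_loc) into its row of
    # req_to_token at positions [pre_len, seq_len).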
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
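    # Serially accumulate the extend lengths of all preceding requests to locate where
    # this request's slice begins inside out_cache_loc.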
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
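    """Return, for each request, the KV-cache slot of its last cached prefix token,
    or -1 when the request has no cached prefix. Dispatches to a Triton kernel unless
    the torch_native attention backend is selected."""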
    if global_server_args_dict["attention_backend"] != "torch_native":
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result