from __future__ import annotations

import hashlib
from enum import Enum, auto

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

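# A rough usage sketch of the flow described in the module docstring above
# (hypothetical call sites, for orientation only; the exact signatures live in
# the respective modules):
#
#   batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool_allocator, ...)
#   batch.prepare_for_extend()                     # allocate KV slots, build batched tensors
#   worker_batch = batch.get_model_worker_batch()  # GPU-forward subset of the batch
#   forward_batch = ForwardBatch.init_new(worker_batch, model_runner)
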
# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
}
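
# Note: the values above are only the ServerArgs class-level defaults. At runtime
# the scheduler is expected to overwrite this dict with the parsed server
# arguments before any batch is built, e.g. (hypothetical call site):
#
#   global_server_args_dict.update(
#       {k: getattr(server_args, k) for k in global_server_args_dict}
#   )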

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }
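
# Example: a request that stops on an EOS token gets
# FINISH_MATCHED_TOKEN(matched=eos_token_id), which serializes to
# {"type": "stop", "matched": eos_token_id}; a request that exhausts
# max_new_tokens gets FINISH_LENGTH, serialized as
# {"type": "length", "length": max_new_tokens} (see Req.check_finished below).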


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, e.g. from one image, one video, or one audio clip.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.array]]
    pixel_values: Union[torch.Tensor, np.array] = None
    image_grid_thws: Union[torch.Tensor, np.array] = None
    video_grid_thws: Union[torch.Tensor, np.array] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.array] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)

            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
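        # Example with hypothetical values: the flattened feature tensor is hashed
        # with SHA-256, truncated to 8 bytes (a 64-bit int), and the pad value is
        # that hash modulo 2**30, so every placeholder token of this item shares
        # one deterministic id that prefix caching can match on:
        #   item.hash      -> e.g. 0x1A2B3C4D5E6F7081
        #   item.pad_value -> item.hash % (1 << 30)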

    def is_audio(self):
        return (
            self.modality == Modality.AUDIO
        ) and not MultimodalDataItem.is_empty_list(self.audio_features)

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def is_video(self):
        return (
            self.modality == Modality.VIDEO
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Qwen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret
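
    # A minimal construction sketch (hypothetical inputs): the multimodal
    # processor is assumed to hand over already-built MultimodalDataItem objects
    # under "mm_items"; invalid items are dropped and pad values are assigned here.
    #
    #   mm = MultimodalInputs.from_dict(
    #       {
    #           "mm_items": [MultimodalDataItem(modality=Modality.IMAGE, pixel_values=px)],
    #           "im_token_id": image_token_id,  # hypothetical token id
    #       }
    #   )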

    def contains_image_inputs(self) -> bool:
        """Whether any valid image item is present."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Whether any valid audio item is present."""
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(item.is_valid() for item in self.mm_items)

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to True.
        # Note: We should never set finished_reason in the middle of the event loop;
        # the request would be filtered out and never get a response.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the
        # finished_reason at the end of the event loop.
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter is incremented whenever the
        # request is chunked and decremented whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Latency Breakdown
        self.queue_time_start = None
        self.queue_time_end = None

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0
        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # used for warmup because we don't have a pair yet when init
        self.skip_kv_transfer: bool = False
        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

        # For the overlap schedule, we delay the KV transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap case.
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
571
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        elif enable_hierarchical_cache:
            # in case last_node is evicted during scheduling, we need to update the prefix_indices
            while self.last_node.evicted:
                self.prefix_indices = self.prefix_indices[
                    : -len(self.last_node.host_value)
                ]
                self.last_node = self.last_node.parent

        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
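        # Example: for a prompt of 10 tokens and no output yet, the first call sets
        # read_offset = 10 and surr_offset = 10 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5,
        # so it returns (all_ids[5:], 5): five surrounding context tokens plus the
        # tokens that still need to be detokenized incrementally.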

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0
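# Global counter used to assign an incrementing id to each batch; presumably
# bumped when a ModelWorkerBatch is created from a ScheduleBatch (not shown in
# this excerpt).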


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprobs are not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
870
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
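        # Example: with page_size == 16, a request whose seqlen is exactly 32 has
        # filled its last KV page, so its next decode token needs a fresh page,
        # while a request at seqlen 33 still has free slots in its current page.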

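    # Check whether the KV cache can hold one more decode step for all requests,
    # budgeting a full page for every request that would cross a page boundary.
    # If not, try to evict from the tree cache and check again.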
    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, the filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
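            # Sort so that pop() retracts the requests with the fewest decoded tokens
            # first (ties broken toward longer prompts), which discards the least
            # amount of completed decode work.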
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
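            # Budget `retract_decode_steps` tokens per remaining request, plus the
            # per-request draft-token headroom (topk * num_steps + num_draft_tokens)
            # when speculative decoding is enabled.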
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

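        # Feed the tokens generated in the previous step to the penalizers (e.g.,
        # frequency/presence penalties). With the overlap scheduler, output_ids lags
        # one step behind, so fall back to the last known token of each request.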
        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
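            # (the previous step's forward pass may still hold a reference to the
            # old seq_lens tensor, so build a new one instead)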
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
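        # With page_size == 1 any free slot works. The paged path also needs the
        # slot of each request's previous token (seq_lens was already advanced above,
        # hence the -2 index) so the allocator can keep filling the current page
        # before opening a new one.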
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() prepares each penalizer based on Batch.reqs and therefore
        # has to see the pre-merge Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
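        # Merge the per-request logprob metadata, padding the side that did not
        # request logprobs with 0 / None so the lists stay aligned with self.reqs.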
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
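        # (these attention backends consume sequence lengths on the host, so take a
        # single device-to-host copy here)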
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
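    # The sequence lengths on CPU (only set for attention backends that need it)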
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
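    # One program per request: copy the request's newly allocated KV-cache slot
    # indices (out_cache_loc) into its row of req_to_token at positions
    # [pre_len, seq_len). cumsum_start is the request's offset into the flattened
    # out_cache_loc buffer.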
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
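    # For each request, return the KV-cache slot of its last cached (prefix) token,
    # or -1 if it has no cached prefix, so paged allocation can continue within the
    # request's current page.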
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )