# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

from __future__ import annotations

import copy
import dataclasses
import hashlib
import logging
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
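
# Illustrative sketch (not executed here): the scheduler typically builds a
# `ScheduleBatch` from a list of `Req` objects and prepares it for a forward
# pass roughly as follows (argument values are placeholders, not defaults):
#
#   batch = ScheduleBatch.init_new(
#       reqs=reqs,
#       req_to_token_pool=req_to_token_pool,
#       token_to_kv_pool_allocator=token_to_kv_pool_allocator,
#       tree_cache=tree_cache,
#       model_config=model_config,
#       enable_overlap=False,
#       spec_algorithm=spec_algorithm,
#       enable_custom_logit_processor=False,
#   )
#   batch.prepare_for_extend()  # allocate KV slots and build sampling info
#
# The resulting batch is then converted to a ModelWorkerBatch / ForwardBatch as
# described in the module docstring.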

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
}
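
# Note: these entries mirror the `ServerArgs` defaults; at runtime the scheduler
# is expected to overwrite them with the actual server arguments. Code in this
# file reads them directly, e.g. (see prepare_for_extend below):
#
#   if global_server_args_dict["attention_backend"] != "torch_native":
#       ...  # use the Triton kernel to fill req_to_token_pool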

logger = logging.getLogger(__name__)


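# The finish-reason classes below serialize to OpenAI-style payloads, e.g.
# (illustrative values):
#
#   FINISH_MATCHED_TOKEN(matched=2).to_json()  -> {"type": "stop", "matched": 2}
#   FINISH_LENGTH(length=128).to_json()        -> {"type": "length", "length": 128}
#   FINISH_ABORT("oom").to_json()              -> {"type": "abort", "message": "oom",
#                                                  "status_code": None, "err_type": None}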
class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, e.g. from a single image, video, or audio input.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.array]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)

            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            if tensor.is_cuda:
                tensor_cpu = torch.frombuffer(
                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
                ).clone()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
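        # Illustrative example (made-up value): if the 64-bit content hash is
        # 0x1234_5678_9ABC_DEF0, then pad_value = hash % 2**30; this value is
        # later used as a fake token id / radix-cache prefix key for the item
        # (see the comments in MultimodalInputs.from_dict below).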

    def is_audio(self):
        return (
            self.modality == Modality.AUDIO
        ) and not MultimodalDataItem.is_empty_list(self.audio_features)

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def is_video(self):
        return (
            self.modality == Modality.VIDEO
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def validate(self):
        ...
        # TODO


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # QWen2-VL related
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [
            item
            for item in ret.mm_items
            if item.is_audio() or item.is_image() or item.is_video()
        ]

        assert len(ret.mm_items) != 0

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "modalities",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """Whether the multimodal inputs contain any image items."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Whether the multimodal inputs contain any audio items."""
        return any(item.is_audio() for item in self.mm_items)

    def collect_image_inputs(self) -> List[torch.Tensor]:
        return [item.pixel_values for item in self.mm_items if item.is_image()]

    def merge(self, other: MultimodalInputs):
        """
        Merge another MultimodalInputs into this one when two requests are merged.
        """

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked and decrements whenever a chunk of the request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Latency Breakdown
        self.queue_time_start = None
        self.queue_time_end = None

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = "0.0.0.0"
        self.bootstrap_room: Optional[int] = None
        self.disagg_kv_sender: Optional[KVSender] = None

        # used for warmup because we don't have a pair yet when init
        self.skip_kv_transfer: bool = False
        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
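        # Worked example (illustrative): with 10 unpadded prompt tokens and the
        # default offset of 5, the first call sets read_offset=10 and
        # surr_offset=5, so we return the last 5 prompt tokens plus all output
        # tokens, together with a relative read offset of 5.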
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
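
    # Illustrative sizing (assumed numbers): with page_size == 16 and two requests
    # whose seqlens are 32 and 33, only the first starts a new page on the next
    # decode step, so new_page_count_next_decode() == 1 and check_decode_mem()
    # requires 1 * buf_multiplier * 16 free tokens before decoding can proceed.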

    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
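            # Free tokens needed to keep decoding `num_reqs` requests for
            # `global_config.retract_decode_steps` more steps; speculative decoding
            # reserves extra headroom for its draft/verify tokens.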
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
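        """Reset the batch for an IDLE forward pass with no input tokens."""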
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
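        """Turn the batch into a decode batch: feed the last sampled tokens back as
        input, grow every sequence by one position, and allocate KV-cache slots for
        the new tokens."""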
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
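        """Drop finished requests (and an optionally excluded chunked request) from
        the batch, or keep only the explicitly provided keep_indices."""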
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
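        """Merge another ScheduleBatch into this one by concatenating its
        per-request tensors and metadata."""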
        # The penalizer orchestrator must be merged before Batch.reqs is merged,
        # because orchestrator.merge() reads Batch.reqs while preparing each
        # penalizer, so it has to run on the not-yet-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
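        """Pack the fields needed for the model forward pass into a ModelWorkerBatch."""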
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["enable_flashmla"]
            or global_server_args_dict["attention_backend"] == "fa3"
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
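    # Each program instance handles one request: it copies that request's newly
    # allocated KV-cache slot indices (a contiguous chunk of out_cache_loc) into
    # row `req_pool_index` of req_to_token, starting at position `pre_len`.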
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
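
# A minimal launch sketch (an illustrative assumption based on the signature above,
# not necessarily how the scheduler invokes it): one program per request, with the
# row stride taken from the req_to_token map itself.
#
#   write_req_to_token_pool_triton[(bs,)](
#       req_to_token, req_pool_indices, pre_lens, seq_lens,
#       extend_lens, out_cache_loc, req_to_token.shape[1],
#   )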


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
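    # KV-cache location of the last cached prefix token for each request, or -1
    # when nothing is cached.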
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )