from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
26
27
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
28
29
30
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
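
A typical (simplified) life cycle, sketched under the usual Scheduler/TpModelWorker
wiring (the exact call sites live in scheduler.py and tp_worker.py):

    batch = ScheduleBatch.init_new(
        reqs, req_to_token_pool, token_to_kv_pool_allocator, tree_cache,
        model_config, enable_overlap, spec_algorithm, enable_custom_logit_processor,
    )
    batch.prepare_for_extend()
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)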
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
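# (Used by Req.init_incremental_detokenize: the last few already-decoded tokens are
# re-fed to the detokenizer as surrounding context so incremental detokenization stays stable.)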

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
}
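
# Note: the values above are only the class-level defaults taken from ServerArgs.
# At runtime the scheduler/model runner is expected to overwrite this dict with the
# parsed server arguments, roughly like (illustrative sketch):
#
#     global_server_args_dict.update(
#         {"attention_backend": server_args.attention_backend, ...}
#     )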

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal (image and audio) related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    data_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # Qwen2-VL related
    # [num_of_images, t, h, w]
    image_grid_thws: torch.Tensor = None
    mrope_position_delta: Optional[torch.Tensor] = None
    # Qwen2-VL video related
    video_token_id: Optional[int] = None
    video_grid_thws: List[Tuple[int, int, int]] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # DeepSeek-VL2 related
    images_emb_mask: Optional[List[torch.Tensor]] = None
    image_spatial_crop: Optional[List[torch.Tensor]] = None

    # The id of the single-image placeholder token
    im_token_id: Optional[torch.Tensor] = None

    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None
    # [num_images, 2 (w, h)]
    tgt_sizes: Optional[list] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None
    audio_features: Optional[List[torch.Tensor]] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            pixel_values=obj["pixel_values"],
            data_hashes=obj["data_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
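        # (The modulo by 2**30 keeps these fake ids non-negative and bounded, so they
        # behave like ordinary token ids for prefix matching.)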

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "images_emb_mask",
            "image_spatial_crop",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
            "audio_start_id",
            "audio_end_id",
            "audio_features",
            "audio_feature_lens",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        # validate
        assert (
            isinstance(ret.pixel_values, torch.Tensor)
            or isinstance(ret.pixel_values, np.ndarray)
            or isinstance(ret.pixel_values, list)
        )

        assert ret.audio_features is None or isinstance(ret.audio_features, list)

        return ret

    def contains_image_inputs(self) -> bool:
        """Whether this input contains any image data."""
        return self.pixel_values is not None and self.pixel_values != []

    def contains_audio_inputs(self) -> bool:
        """Whether this input contains any audio data."""
        return self.audio_features is not None and self.audio_features != []

    def merge(self, other: MultimodalInputs):
        """
        Merge image inputs when requests are being merged.
        """
        if isinstance(self.pixel_values, list):
            # in some rare cases, pixel values are list of patches with different shapes
            # e.g. minicpm
            self.pixel_values += other.pixel_values
        else:
            assert (
                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # args would be stacked along first dim
        # usually these are already tensors
        stack_args = [
            # TODO: merge with image_grid_thws, basically the same thing
            "tgt_sizes",
            "image_spatial_crop",
        ]
        for arg in stack_args:
            if getattr(self, arg, None) is None:
                setattr(self, arg, getattr(other, arg, None))
            elif getattr(other, arg, None) is not None:
                # self and other both not None
                setattr(
                    self,
                    arg,
                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
                )

        if self.image_grid_thws is None:
            self.image_grid_thws = other.image_grid_thws
        elif other.image_grid_thws is not None:
            self.image_grid_thws = torch.concat(
                [self.image_grid_thws, other.image_grid_thws]
            )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.data_hashes += other.data_hashes
        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

        # args needed to be merged
        optional_args = [
            "audio_features",
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "images_emb_mask",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to True.
        # Note: we should never set finished_reason in the middle; the request would get
        # filtered out and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the
        # finished_reason at the end of the event loop.
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
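        # Example (illustrative): if the unpadded prompt has 10 tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET is 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and surr_offset = 5,
        # so tokens 5..9 are re-decoded only as surrounding context and new text is
        # emitted for tokens at positions >= 10.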
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter is incremented whenever the
        # request is chunked and decremented whenever a chunk of it is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

446
        # Embedding (return values)
447
        self.embedding = None
Lianmin Zheng's avatar
Lianmin Zheng committed
448

449
        # Constrained decoding
450
        self.grammar: Optional[BaseGrammarObject] = None
Liangsheng Yin's avatar
Liangsheng Yin committed
451

452
        # The number of cached tokens that were already cached in the KV cache
453
        self.cached_tokens = 0
454
        self.already_computed = 0
455

456
457
458
459
460
        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = "0.0.0.0"
        self.bootstrap_room: Optional[int] = None
        self.disagg_kv_sender: Optional[KVSender] = None

        # Used for warmup because we don't have a prefill/decode pair yet at init time
        self.skip_kv_transfer: bool = False
        # The start index of the KV cache that has already been sent.
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from the prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0
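# Process-wide running counter used to assign a unique id (bid) to each batch.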


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
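        # Evict from the tree cache if needed. The check below reserves up to one extra
        # page per sequence on top of `extend_num_tokens`, since a paged allocator may
        # need to open a new, partially filled page for the tail of each sequence.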
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
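        # Decode adds at most one token per sequence, so one page of headroom per
        # sequence is sufficient for this allocation.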
        if (
            self.token_to_kv_pool_allocator.available_size()
            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None
        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
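        # Decode appends one token per request per step, so `bs` free KV slots are
        # enough for the next step.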
        if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

        self.tree_cache.evict(bs)

        if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

        return False

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
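            # Rough sketch of the headroom: each remaining request should be able to run
            # about `global_config.retract_decode_steps` more decode steps; with speculative
            # decoding, extra draft/verify tokens per request are reserved on top of that.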
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    (len(req.prefix_indices) + server_args.page_size - 1)
                    // server_args.page_size
                    * server_args.page_size
                )
1196
1197
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
1198
                ]
1199
                self.token_to_kv_pool_allocator.free(token_indices)
1200
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
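        # (Assumption about the caller: the scheduler feeds this clamped ratio back
        # into its estimate of how many of the remaining max_new_tokens the surviving
        # requests will actually generate.)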

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
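        # (Assumption: by decode time the encoder side has already been computed
        # during prefill, so every request is treated as cached.)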
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
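                # (Assumption about overlap scheduling: the previous step's output_ids
                # tensor may not be materialized yet, so the last token of each request
                # is rebuilt from the CPU-side lists before updating the penalizers.)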
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
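            # seq_lens was already advanced by one above, so seq_lens - 2 indexes the
            # last token that currently owns a KV slot; the paged allocator can then
            # try to place the new token in the same page (reading of the allocator
            # is an assumption).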
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is
        # because orchestrator.merge() depends on Batch.reqs during the preparation of
        # each penalizer, so it needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            if (
                global_server_args_dict["enable_flashinfer_mla"]
                or global_server_args_dict["enable_flashmla"]
                or global_server_args_dict["attention_backend"] == "fa3"
            ):
                decode_seq_lens = self.seq_lens.cpu()
            else:
                decode_seq_lens = None
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            decode_seq_lens = None
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            decode_seq_lens=decode_seq_lens,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For decode
    decode_seq_lens: Optional[torch.Tensor]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
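

# A minimal launch sketch (illustrative only; the variable names are assumptions, not
# the exact call site). The kernel indexes per-request arrays with its program id, so
# the grid is the batch size and the last argument is the row stride of the
# req_to_token table:
#
#   write_req_to_token_pool_triton[(bs,)](
#       req_to_token_pool.req_to_token,
#       req_pool_indices,
#       prefix_lens,
#       seq_lens,
#       extend_lens,
#       out_cache_loc,
#       req_to_token_pool.req_to_token.shape[1],
#   )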


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )
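

# Illustrative example of get_last_loc: with prefix_lens_tensor = [3, 0] and
# req_to_token[0] starting with [10, 11, 12, ...], it returns [12, -1]: the KV slot
# of the last cached prefix token, or -1 for a request with no cached prefix.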