from __future__ import annotations

from enum import Enum, auto

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
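# Note: the values below are the ServerArgs class defaults; they are placeholders
# that are expected to be overwritten with the live server arguments at runtime.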
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, from a single image, video, audio, or other source.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.array]]
    pixel_values: Union[torch.Tensor, np.array] = None
    image_grid_thws: Union[torch.Tensor, np.array] = None
    video_grid_thws: Union[torch.Tensor, np.array] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.array] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data.
        """

        def hash_feature(f):
            if isinstance(f, list):
                return hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return hash(arr_bytes)
            return hash(f)

        if self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
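        # The hash-derived pad value doubles as a fake token id for radix-cache
        # prefix matching (see MultimodalInputs.from_dict); keep it a non-negative
        # value that fits in 30 bits.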
        self.pad_value = self.hash % (1 << 30)

    def is_audio(self):
        return (
            self.modality == Modality.AUDIO
        ) and not MultimodalDataItem.is_empty_list(self.audio_features)

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def is_video(self):
        return (
            self.modality == Modality.VIDEO
        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

    def validate(self):
        ...
        # TODO


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Qwen2-VL related
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[torch.Tensor] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [
            item
            for item in ret.mm_items
            if item.is_audio() or item.is_image() or item.is_video()
        ]

        assert len(ret.mm_items) != 0

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "modalities",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """ """
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """ """
        return any(item.is_audio() for item in self.mm_items)

    def collect_image_inputs(self) -> List[torch.Tensor]:
        return [item.pixel_values for item in self.mm_items if item.is_image()]

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when two requests are merged.
        """

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.

        # args needed to be merged
        optional_args = [
            "items",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true.
        # Note: never set finished_reason in the middle of the loop; the request would be filtered out and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
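        # Illustrative example: for a 10-token prompt, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and
        # surr_offset = max(10 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0) = 5,
        # so the last 5 prompt tokens are re-read as surrounding context.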
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked and decrements whenever a chunk of it has been
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = "0.0.0.0"
        self.bootstrap_room: Optional[int] = None
        self.disagg_kv_sender: Optional[KVSender] = None

        # used for warmup because we don't have a pair yet when init
        self.skip_kv_transfer: bool = False
        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
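        # Whatever is in fill_ids beyond the matched prefix still needs a prefill forward.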
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
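        # If the free pool is too small, first try to evict enough tokens from the
        # radix cache; the allocation below still fails hard if that is not enough.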
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
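        # Conservative capacity check: besides extend_num_tokens, reserve up to one
        # extra page of headroom per sequence (worst case for page-aligned allocation).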
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
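        # Decode adds one token per sequence; in the worst case each sequence starts
        # a fresh page, hence the one-page-per-sequence bound below.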
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
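        # Encoder tokens are dropped from the decoder-side input_ids/seq_lens, and
        # their cache slots are split out into encoder_out_cache_loc.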
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)

    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
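        # First try to satisfy the requirement with free slots; otherwise evict from
        # the tree cache and check again.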

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )
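        # sorted_indices is consumed from the back (pop), so this descending sort
        # retracts the requests with the fewest decoded tokens first.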

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
1219
                ]
1220
                self.token_to_kv_pool_allocator.free(token_indices)
1221
                self.req_to_token_pool.free(req.req_pool_idx)
1222
1223
            else:
                # TODO: apply more fine-grained retraction
1224
                last_uncached_pos = (
1225
1226
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
1227
1228
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
1229
                ]
1230
                self.token_to_kv_pool_allocator.free(token_indices)
1231
                self.req_to_token_pool.free(req.req_pool_idx)
1232
1233
1234
1235
1236
1237
1238

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
1239
                    - self.token_to_kv_pool_allocator.available_size()
1240
1241
                )
                residual_size = max(0, residual_size)
Lianmin Zheng's avatar
Lianmin Zheng committed
1242
                self.tree_cache.evict(residual_size)
1243

1244
            req.reset_for_retract()
Liangsheng Yin's avatar
Liangsheng Yin committed
1245

1246
        self.filter_batch(keep_indices=sorted_indices)
1247

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

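        # Re-estimate how much of the remaining requests' max_new_tokens budget
        # will be used: tokens decoded so far plus a retract_decode_steps buffer,
        # capped at 1.0.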
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
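        """Reset the batch to an empty IDLE batch (no requests, empty tensors)."""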
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
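        """Prepare the batch for the next decode step: feed the sampled tokens
        back as inputs and allocate one new KV-cache slot per request."""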
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
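            # seq_lens was already advanced above, so the last written slot is at
            # index seq_lens - 2; it is passed to the paged allocator as each
            # request's last token location.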
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
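        """Drop finished requests (or keep only `keep_indices`) and slice every
        per-request field accordingly."""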
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
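        """Merge `other` into this batch by concatenating all per-request fields."""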
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
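        """Pack the tensors and metadata needed for the model forward into a
        ModelWorkerBatch."""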
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            global_server_args_dict["enable_flashinfer_mla"]
            or global_server_args_dict["enable_flashmla"]
            or global_server_args_dict["attention_backend"] == "fa3"
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


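# Write the newly allocated KV-cache locations (out_cache_loc) into the
# req_to_token mapping. One program handles one request, copying its extend
# region in BLOCK_SIZE chunks.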
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


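# Return the KV-cache location of each request's last prefix token,
# or -1 for requests with an empty prefix.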
@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )