from __future__ import annotations

import enum

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we may consider removing it in the future.
"""

import copy
import dataclasses
import logging
import threading
import time
from enum import Enum, auto
from http import HTTPStatus
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_flashinfer_cutlass_moe_fp4_allgather",
    "disable_radix_cache",
    "enable_dp_lm_head",
    "enable_fp32_lm_head",
    "flashinfer_mxfp4_moe_precision",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "speculative_attention_mode",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_multimodal",
    "enable_symm_mem",
    "enable_custom_logit_processor",
    "disaggregation_mode",
    "enable_deterministic_inference",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
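# Note (assumption about runtime behavior): at import time these are just the ServerArgs
# class-level defaults; the scheduler is expected to overwrite this dict with the live
# server arguments when the server starts.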

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]


@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
    one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None

    # the precomputed embeddings, passed as final encoder embeddings
    # Exactly one of `feature` and `precomputed_embeddings` is set; the other is empty.
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
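        # Note (assumption, based on how pad_value is used elsewhere in SGLang): deriving
        # the pad value from the content hash gives identical multimodal inputs identical
        # placeholder token ids, so prefix/radix caching can tell different inputs apart.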

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class RequestStage(str, enum.Enum):
    # prefill
    PREFILL_WAITING = "prefill_waiting"

    # disaggregation prefill
    PREFILL_PREPARE = "prefill_prepare"
    PREFILL_BOOTSTRAP = "prefill_bootstrap"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"

    # disaggregation decode
    DECODE_PREPARE = "decode_prepare"
    DECODE_BOOTSTRAP = "decode_bootstrap"
    DECODE_WAITING = "decode_waiting"
    DECODE_TRANSFERRED = "decode_transferred"


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
        priority: Optional[int] = None,
        metrics_collector: Optional[SchedulerMetricsCollector] = None,
        extra_key: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of the KV cache that has been removed in local attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # extra key for classifying the request (e.g. cache_salt)
        if lora_id is not None:
            extra_key = (
                extra_key or ""
            ) + lora_id  # lora_id is concatenated to the extra key

        self.extra_key = extra_key
        self.lora_id = lora_id

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size
        self.priority = priority

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
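        # Worked example (hypothetical numbers): with a prompt of 10 tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() below sets read_offset = 10 and surr_offset = 5,
        # so tokens 5..9 are the surrounding context used to detokenize new output.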
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The node to lock until for swa radix tree lock ref
        self.swa_uuid_for_lock: Optional[int] = None

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked, and decrements whenever a chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has already been sent to the detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
            self.output_token_ids_logprobs_val: List[
                Union[List[float], torch.Tensor]
            ] = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
        self.output_topk_p = None
        self.output_topk_index = None

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.metrics_collector = metrics_collector
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None
        self.last_tic = time.monotonic()

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For the overlap schedule, we delay the KV transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap schedule.
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    @property
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        return self.sampling_params.max_new_tokens == 0

    def add_latency(self, stage: RequestStage):
        if self.metrics_collector is None:
            return
        assert stage.name in RequestStage.__members__, f"{stage=} is invalid"
        now = time.monotonic()
        self.metrics_collector.observe_request_latency_seconds(
            stage.value, now - self.last_tic
        )
        self.last_tic = now
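
    # Example (hypothetical) of how the scheduler can use add_latency() above:
    #   req.add_latency(RequestStage.PREFILL_WAITING)
    # records, under that stage's label, the time elapsed since the previous
    # add_latency() call (or since the request was created).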

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=RadixKey(
                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
                ),
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )
            self.surr_and_decode_ids = (
                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
            )
            self.cur_decode_ids_len = len(self.output_ids)
        else:
            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
            self.cur_decode_ids_len = len(self.output_ids)

        return self.surr_and_decode_ids, self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        if last_token_id > self.vocab_size or last_token_id < 0:
            if self.sampling_params.stop_token_ids:
                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
            if self.eos_token_ids:
                self.output_ids[-1] = next(iter(self.eos_token_ids))
            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None
    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: torch.Tensor = None  # shape: [b], int32

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
934
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
        None
    )
960

    # Whether to return hidden states
    return_hidden_states: bool = False

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = -1

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert (
                tree_cache is None
                or isinstance(tree_cache, SWARadixCache)
                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            is_prefill_only=all(req.is_prefill_only for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
        else:
            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        # Overestimate the number of tokens: assume each request needs a new page.
        num_tokens = (
            extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        )
        self._evict_tree_cache_if_needed(num_tokens)
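        # e.g. (hypothetical numbers): with page_size = 16, 4 requests and
        # extend_num_tokens = 100, this conservatively makes room for
        # 100 + 4 * 16 = 164 token slots before calling the allocator.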

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        # Overestimate the number of tokens: assume each request needs a new page.
        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs, self.reqs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(
            list(chain.from_iterable(input_ids)), dtype=torch.int64
        ).to(self.device, non_blocking=True)
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic checks
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
                if isinstance(self.tree_cache, SWAChunkCache):
                    self.tree_cache.evict_swa(
                        req, pre_len, self.model_config.attention_chunk_size
                    )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            #
            # Key variables:
            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
            # - extend_input_len: Number of tokens that need to be processed in this extend batch
            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
            #    and prefix_indices are the cached/shared prefix tokens)
            #
            if req.logprob_start_len >= pre_len:
                # Optimization for prefill-only requests: When we only need logprobs at
                # positions beyond the input sequence (to score next-token likelihood), skip all
                # input logprob computation during prefill since no generation will occur.
                if self.is_prefill_only and req.logprob_start_len == len(
                    req.origin_input_ids
                ):
                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
                    req.extend_logprob_start_len = req.extend_input_len
                else:
                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
                    #
                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
                    # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
                    # This means: "compute logprobs from position 3 onwards in extend batch"
                    req.extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len,
                        req.extend_input_len,
                        req.seqlen - 1,
                    )
            else:
                # logprob_start_len is before the current extend batch, so start from beginning
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. This is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )
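        # Illustrative example (hypothetical numbers, not from the source): with
        # page_size = 4, a request with prefix_len = 6 and seq_len = 11 needs
        # 11 - 6 = 5 new slots. last_loc points at the KV slot of the last
        # prefix token, so the paged allocator can finish the partially filled
        # page before opening new ones.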

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]
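        # Both paths fill req_to_token[req_pool_idx, prefix_len:seq_len] with the
        # newly allocated KV slot indices; the Triton kernel does this in parallel
        # (one program per request), while the loop above is the non-Triton fallback.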

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)
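        # Net effect: each running request joins the extend batch as a one-token
        # extend whose prefix covers everything already in its KV cache, so the
        # prefill and decode requests can share a single forward pass.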

    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
        page_size = self.token_to_kv_pool_allocator.page_size
        requests = (
            self.reqs
            if selected_indices is None
            else [self.reqs[i] for i in selected_indices]
        )
        if page_size == 1:
            return len(requests)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
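        # Illustrative example (hypothetical numbers): with page_size = 4 and
        # overlap disabled, a request with seqlen = 9 holds 8 KV entries, so its
        # next token lands on slot 8, which starts a new page (8 % 4 == 0) and is
        # counted below.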
        return (
            sum(1 for req in requests if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(
        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
    ):
        num_tokens = (
            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
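        # num_tokens is the worst-case slot demand for the next decode step:
        # every request that crosses a page boundary is charged a full page.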

        self._evict_tree_cache_if_needed(num_tokens)
        return self._is_available_size_sufficient(num_tokens)

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while first_iter or (
            not self.check_decode_mem(selected_indices=sorted_indices)
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)
            self.release_req(idx, len(sorted_indices), server_args)

            if len(retracted_reqs) == 0:
                # Corner case: only one request left
                raise ValueError(
                    "Failed to retract any request. No space left for only one request."
                )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
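        # Heuristic refresh of the scheduler's new-token ratio: assume every
        # surviving request decodes at least `retract_decode_steps` more tokens,
        # then express that as a fraction of the total max_new_tokens budget.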

        return retracted_reqs, new_estimate_ratio

    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
        req = self.reqs[idx]
        seq_lens_cpu = self.seq_lens.cpu().numpy()

        if server_args.disaggregation_mode == "decode":
            req.offload_kv_cache(
                self.req_to_token_pool, self.token_to_kv_pool_allocator
            )
        if isinstance(self.tree_cache, ChunkCache):
            # ChunkCache does not have eviction
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
        else:
            # TODO: apply more fine-grained retraction
            last_uncached_pos = (
                len(req.prefix_indices) // server_args.page_size
            ) * server_args.page_size
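            # Round the cached prefix length down to a page boundary and free
            # everything after it, so the trailing partial page is returned to
            # the allocator together with the decoded tokens.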
            token_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool_allocator.free(token_indices)
            self.req_to_token_pool.free(req.req_pool_idx)

            # release the last node
            if self.is_hybrid:
                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
            else:
                self.tree_cache.dec_lock_ref(req.last_node)

            # NOTE(lsyin): we should use the newly evictable memory instantly.
            num_tokens = remaing_req_count * global_config.retract_decode_steps
            self._evict_tree_cache_if_needed(num_tokens)

        req.reset_for_retract()

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if (
            self.spec_algorithm.is_eagle()
            or self.spec_algorithm.is_standalone()
            or self.spec_algorithm.is_ngram()
        ):
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
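                # With overlap scheduling, self.output_ids still belongs to the
                # previous step, so gather the latest token per request from
                # req.output_ids (falling back to the last prompt token when
                # nothing has been generated yet).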
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Free memory
        if isinstance(self.tree_cache, SWAChunkCache):
            for req in self.reqs:
                self.tree_cache.evict_swa(
                    req, req.seqlen - 1, self.model_config.attention_chunk_size
                )

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
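            # seq_lens was incremented above, so the last already-written KV
            # entry of each request lives at position seq_lens - 2.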
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
                has_been_filtered = False
            else:
                has_been_filtered = True
            self.spec_info.filter_batch(
                new_indices=keep_indices_device,
                has_been_filtered=has_been_filtered,
            )

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
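        # The three branches above keep the per-request logprob metadata aligned
        # with the merged request order, padding with 0 / None for requests that
        # do not ask for logprobs.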
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        seq_lens_cpu = (
            seq_lens_cpu_cache
            if seq_lens_cpu_cache is not None
            else self.seq_lens.cpu()
        )

        global bid
        bid += 1
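        # bid is a module-level counter, so batch ids increase monotonically
        # across all batches produced by this scheduler process.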
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
            is_prefill_only=self.is_prefill_only,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
        )

    def _evict_tree_cache_if_needed(self, num_tokens: int):
        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
            return

        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
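            # The full-attention and SWA KV pools are sized independently, so a
            # separate eviction target is computed for whichever pool is short.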

            if full_available_size < num_tokens or swa_available_size < num_tokens:
                if self.tree_cache is not None:
                    full_num_tokens = max(0, num_tokens - full_available_size)
                    swa_num_tokens = max(0, num_tokens - swa_available_size)
                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
        else:
            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
                if self.tree_cache is not None:
                    self.tree_cache.evict(num_tokens)

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens

    def _available_and_evictable_str(self) -> str:
        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
            full_evictable_size = self.tree_cache.full_evictable_size()
            swa_evictable_size = self.tree_cache.swa_evictable_size()
            return (
                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
            )
        else:
            available_size = self.token_to_kv_pool_allocator.available_size()
            evictable_size = self.tree_cache.evictable_size()
            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_ids: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The original sequence lengths, Qwen-1M related
    orig_seq_lens: Optional[torch.Tensor] = None

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
        None
    )
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = -1

    # Overlap event
    launch_done: Optional[threading.Event] = None

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)
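    # cumsum_start is this request's offset into the flat out_cache_loc buffer,
    # i.e. the sum of extend_lens over all preceding requests.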

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if (
        global_server_args_dict["attention_backend"] != "ascend"
        and global_server_args_dict["attention_backend"] != "torch_native"
    ):
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
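    # For each request, return the KV slot index of its last cached prefix
    # token, or -1 when there is no prefix.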
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result