from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we are considering removing it in the future.
"""

import copy
import dataclasses
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_radix_cache",
    "enable_dp_attention",
    "enable_two_batch_overlap",
    "enable_dp_lm_head",
    "enable_deepep_moe",
    "deepep_mode",
    "enable_ep_moe",
    "enable_flashinfer_moe",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "deepep_config",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_single",
    "speculative_accept_threshold_acc",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_triton_kernel_moe",
    "enable_multimodal",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
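# Note: these entries start as the ServerArgs class defaults; at startup the scheduler
# typically overwrites them with the live server arguments (see scheduler.py), so other
# modules can read them like a plain dict, e.g. global_server_args_dict["attention_backend"].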
logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str):
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
            )

    @staticmethod
    def all():
        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]

@dataclasses.dataclass
class MultimodalDataItem:
    """
    One MultimodalDataItem contains all inputs for one modality.
    For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
    one for the images and one for the audio.

    We put the common fields first and the model-specific fields in model_specific_data.
    """

    modality: Modality
    hash: int = None
    pad_value: int = None
    offsets: Optional[list] = None

    # the raw features returned by processor, e.g. pixel_values or audio_features
    feature: Union[torch.Tensor, np.ndarray] = None
    # the precomputed embeddings, passed as final encoder embeddings
    # Exactly one of `feature` and `precomputed_embeddings` is set; the other is left empty.
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

    # Model-specific data stored in a dictionary
    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        if (
            "model_specific_data" in self.__dict__
            and name in self.__dict__["model_specific_data"]
        ):
            return self.__dict__["model_specific_data"][name]
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )

    def __setitem__(self, key: str, value: Any):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value

    def set(self, key: str, value: Any):
        self.__setitem__(key, value)
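    # Illustrative usage (hypothetical key name): unknown attributes fall through to
    # `model_specific_data`, e.g.
    #   item.set("image_grid_thw", grid)   # stored via __setitem__ into model_specific_data
    #   item.image_grid_thw                # read back via __getattr__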

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """
        from sglang.srt.managers.mm_utils import hash_feature

        if self.hash is None:
            if self.feature is not None:
                hashed_feature = self.feature
            else:
                hashed_feature = self.precomputed_embeddings
            self.hash = hash_feature(hashed_feature)
        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_modality(self, modality: Modality) -> bool:
        return self.modality == modality

    def is_audio(self):
        return self.modality == Modality.AUDIO

    def is_image(self):
        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

    def is_video(self):
        return self.modality == Modality.VIDEO

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret

    def merge(self, other):
        self.feature += other.feature
        self.offsets += other.offsets
        self.hash = hash((self.hash, other.hash))
        self.set_pad_value()


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "video_token_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        return any(item.is_image() for item in self.mm_items)

    def contains_video_inputs(self) -> bool:
        return any(item.is_video() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(item.is_valid() for item in self.mm_items)

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # Args that need to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
        vocab_size: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # The length of the KV cache that has been evicted in local-attention chunked prefill
        self.evicted_seqlen_local = 0

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to True.
        # Note: we should never set finished_reason in the middle of the event loop;
        # otherwise the request would get filtered out early and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids
        self.vocab_size = vocab_size

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0
        # The SWA radix tree node (identified by uuid) up to which to hold the lock ref
        self.swa_uuid_for_lock: Optional[int] = None

        # Whether the request is chunked. The counter is incremented whenever the request
        # is chunked, and decremented whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has already been sent to the detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=self.adjust_max_prefix_ids(),
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        if last_token_id > self.vocab_size or last_token_id < 0:
            if self.sampling_params.stop_token_ids:
                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
            if self.eos_token_ids:
                self.output_ids[-1] = next(iter(self.eos_token_ids))
            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.swa_uuid_for_lock = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.return_logprob = False
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None
    is_hybrid: bool = False

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    is_extend_in_batch: bool = False
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = 0

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            assert isinstance(tree_cache, SWARadixCache) or isinstance(
                tree_cache, SWAChunkCache
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            is_hybrid=is_hybrid,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
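        # A conservative upper bound used only to decide how much to evict from the tree
        # cache: with paged allocation each sequence may need up to one extra, partially
        # filled page beyond extend_num_tokens, hence the len(seq_lens) * page_size term.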
        num_tokens = (
            extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        )
        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size

        self._evict_tree_cache_if_needed(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"{self._available_and_evictable_str()}"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
                if isinstance(self.tree_cache, SWAChunkCache):
                    self.tree_cache.evict_swa(
                        req, pre_len, self.model_config.attention_chunk_size
                    )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
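        # With page_size == 1 every token slot can be allocated independently.
        # With a larger page size the allocator must continue each request's
        # last, partially filled page, so the last cached location is looked up
        # first via get_last_loc().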
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "feature", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.feature = pixel_values.to(self.device, non_blocking=True)
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
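            # Fallback without Triton: copy each request's newly allocated
            # slots into req_to_token one request at a time.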
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_split_prefill(self):
        self.prepare_for_extend()
        # For split prefill, we need to set the forward mode to SPLIT_PREFILL
        self.forward_mode = ForwardMode.SPLIT_PREFILL

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()
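        # Fold the running decode batch into this prefill batch by treating
        # each running request as a one-token extend.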

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
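        # A request therefore needs a fresh page only when its current KV length
        # is a multiple of page_size; the overlap scheduler sees lengths that are
        # one step ahead, hence the two variants below.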
        return (
            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
            if self.enable_overlap
            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
        )

    def check_decode_mem(self, buf_multiplier=1):
        num_tokens = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        self._evict_tree_cache_if_needed(num_tokens)
        return self._is_available_size_sufficient(num_tokens)

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
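        # Indices are popped from the tail below, so after sorting, the requests
        # with the fewest decoded tokens (ties broken by longer prompts) are
        # retracted first.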
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
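            # Free token slots needed to keep decoding: a fixed number of decode
            # steps per request, plus headroom for speculative draft/verify
            # tokens when spec decoding is enabled.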
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        def _get_available_size():
            if self.is_hybrid:
                return min(
                    self.token_to_kv_pool_allocator.full_available_size(),
                    self.token_to_kv_pool_allocator.swa_available_size(),
                )
            else:
                return self.token_to_kv_pool_allocator.available_size()

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            _get_available_size() < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                if self.is_hybrid:
                    full_available_size = (
                        self.token_to_kv_pool_allocator.full_available_size()
                    )
                    swa_available_size = (
                        self.token_to_kv_pool_allocator.swa_available_size()
                    )
                    assert (
                        full_available_size > 0 and swa_available_size > 0
                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
                else:
                    assert (
                        self.token_to_kv_pool_allocator.available_size() > 0
                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if server_args.disaggregation_mode == "decode":
                req.offload_kv_cache(
                    self.req_to_token_pool, self.token_to_kv_pool_allocator
                )

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                if self.is_hybrid:
                    self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
                else:
                    self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                num_tokens = len(sorted_indices) * global_config.retract_decode_steps
                self._evict_tree_cache_if_needed(num_tokens)

            req.reset_for_retract()

            if len(retracted_reqs) == 0:
                # Corner case: only one request left
                raise ValueError(
                    "Failed to retract any request. No space left for only one request."
                )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # free memory
        if isinstance(self.tree_cache, SWAChunkCache):
            for req in self.reqs:
                self.tree_cache.evict_swa(
                    req, req.seqlen - 1, self.model_config.attention_chunk_size
                )

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
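            # self.seq_lens was already advanced by one above, so the last
            # existing token of each request sits at index seq_lens - 2.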
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
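        # Keep the per-request logprob metadata aligned with self.reqs: pad with
        # 0 / None for whichever side did not request logprobs.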
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            global_server_args_dict["attention_backend"] == "fa3"
            or (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["attention_backend"] == "ascend"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
                seq_lens_cpu_cache
                if seq_lens_cpu_cache is not None
                else self.seq_lens.cpu()
            )
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            is_extend_in_batch=self.is_extend_in_batch,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
        )

    def _evict_tree_cache_if_needed(
        self,
        num_tokens: int,
    ) -> None:
        if isinstance(self.tree_cache, SWAChunkCache):
            return

        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()

            if full_available_size < num_tokens or swa_available_size < num_tokens:
                if self.tree_cache is not None:
                    full_num_tokens = max(0, num_tokens - full_available_size)
                    swa_num_tokens = max(0, num_tokens - swa_available_size)
                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
        else:
            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
                if self.tree_cache is not None:
                    self.tree_cache.evict(num_tokens)

    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
        if self.is_hybrid:
            return (
                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
            )
        else:
            return self.token_to_kv_pool_allocator.available_size() >= num_tokens

    def _available_and_evictable_str(self) -> str:
        if self.is_hybrid:
            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
            full_evictable_size = self.tree_cache.full_evictable_size()
            swa_evictable_size = self.tree_cache.swa_evictable_size()
            return (
                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
            )
        else:
            available_size = self.token_to_kv_pool_allocator.available_size()
            evictable_size = self.tree_cache.evictable_size()
            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor
    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    is_extend_in_batch: bool
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder model
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    hicache_consumer_index: int = 0

    # Overlap event
    launch_done: Optional[threading.Event] = None

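# Each Triton program instance handles one request: it copies the request's
# newly allocated KV slot indices from `out_cache_loc` into `req_to_token` at
# positions [pre_len, seq_len).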

@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


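# Return, for each request, the KV-cache location of its last prefix token,
# or -1 when the request has no cached prefix. Used as `last_loc` for the paged
# allocation paths above.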
def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if (
        global_server_args_dict["attention_backend"] != "ascend"
        and global_server_args_dict["attention_backend"] != "torch_native"
    ):
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result