"torchvision/vscode:/vscode.git/clone" did not exist on "bb2805a669b197582b796f0d86ef8932e1a4396a"
schedule_batch.py 70.8 KB
Newer Older
1
2
from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
26
27
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
28
29
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
Lianmin Zheng's avatar
Lianmin Zheng committed
30
31

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
32
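
Example (rough sketch of the hand-off, not the exact call sites):

    schedule_batch = ScheduleBatch.init_new(reqs, ...)            # scheduler side, mostly CPU data
    model_worker_batch = schedule_batch.get_model_worker_batch()  # subset shipped to the TP worker
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)  # low-level GPU tensors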
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

GLOBAL_SERVER_ARGS_KEYS = [
    "attention_backend",
    "mm_attention_backend",
    "debug_tensor_dump_inject",
    "debug_tensor_dump_output_folder",
    "chunked_prefill_size",
    "device",
    "disable_chunked_prefix_cache",
    "disable_radix_cache",
    "enable_dp_attention",
    "enable_two_batch_overlap",
    "enable_dp_lm_head",
    "enable_deepep_moe",
    "deepep_mode",
    "enable_ep_moe",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "deepep_config",
    "ep_num_redundant_experts",
    "enable_nan_detection",
    "flashinfer_mla_disable_ragged",
    "max_micro_batch_size",
    "disable_shared_experts_fusion",
    "sampling_backend",
    "speculative_accept_threshold_acc",
    "speculative_accept_threshold_single",
    "torchao_config",
    "triton_attention_reduce_in_fp32",
    "num_reserved_decode_tokens",
]

# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
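# Illustrative read access (other modules index this dict by key at runtime), e.g.:
#   backend = global_server_args_dict["attention_backend"]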

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item from one image, video, audio clip, or other source.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # The real data, e.g., pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    # kimi-vl related
    image_grid_hws: Optional[List[torch.Tensor]] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data.
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            Hash a tensor or a list of tensors.
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)
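        # Note (sketch of intent): the hash also serves as a deterministic pad value
        # below; reducing it modulo 2**30 keeps it in a plausible token-id range so
        # identical multimodal items map to identical pad tokens.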

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret
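    # Illustrative usage (hypothetical values):
    #   item = MultimodalDataItem.from_dict({"modality": "IMAGE", "pixel_values": pv})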


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Qwen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """Return True if any item is an image."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Return True if any item is audio."""
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args that need to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # Other args are kept intact.


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: List[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = []
        self.session_id = session_id
        self.input_embeds = input_embeds

        # For cross-encoder models
        self.token_type_ids = token_type_ids

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle; the req would get filtered and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
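        # Example (illustrative): with 12 unpadded prompt tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 12 and surr_offset = 7,
        # so a small window of prompt tokens is re-decoded to keep incremental
        # detokenization consistent.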
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices: torch.Tensor = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node: Any = None
        self.last_host_node: Any = None
        self.host_hit_length = 0

        # Whether it is chunked. The counter is incremented whenever the request
        # is chunked, and decremented whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprobs have already been sent to the detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []
        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # For data parallel rank routing
        self.data_parallel_rank: Optional[int] = data_parallel_rank

        # The start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For the overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap case.
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            (
                self.prefix_indices,
                self.last_node,
                self.last_host_node,
                self.host_hit_length,
            ) = tree_cache.match_prefix(
                key=self.adjust_max_prefix_ids(),
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

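    # The two helpers below snapshot this request's KV cache into a CPU-side copy
    # and load it back later (a sketch of intent: freeing or re-allocating the GPU
    # slots themselves is handled by the caller).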
    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    token_type_ids: torch.Tensor = None  # shape: [b], int64
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
    is_extend_in_batch: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    # hicache pointer for synchronizing data loading from CPU to GPU
    hicache_consumer_index: int = 0

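    # Illustrative construction (the scheduler supplies its own pools and caches):
    #   batch = ScheduleBatch.init_new(
    #       reqs, req_to_token_pool, token_to_kv_pool_allocator, tree_cache,
    #       model_config, enable_overlap, spec_algorithm, enable_custom_logit_processor,
    #   )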
    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
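        # Sketch of the headroom check below: besides the extend_num_tokens that
        # will be written, conservatively leave room for up to one extra page per
        # sequence, since a paged allocator may round each sequence up to a page
        # boundary.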
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        token_type_ids = [
            r.token_type_ids for r in reqs if r.token_type_ids is not None
        ]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )

        token_type_ids_tensor = None
        if len(token_type_ids) > 0:
            token_type_ids_tensor = torch.tensor(
                sum(token_type_ids, []), dtype=torch.int64
            ).to(self.device, non_blocking=True)

        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic checks
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.token_type_ids = token_type_ids_tensor
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
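        """Mix this prefill (extend) batch with a running decode batch so both can
        be executed in a single MIXED forward pass; every running request
        contributes exactly one extend token."""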
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
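        """Return how many requests need a new KV-cache page for their next decode
        token (with page_size == 1 this is simply the number of requests)."""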
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
        return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)

    def check_decode_mem(self, buf_multiplier=1):
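        """Check whether the KV cache has room for the next decode step, evicting
        from the tree cache if needed; returns True when enough space is free."""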
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

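        # Free KV-cache tokens needed to keep decoding `num_reqs` requests for
        # `retract_decode_steps` more steps, plus headroom for speculative
        # draft/verify tokens when spec decoding is enabled.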
        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if server_args.disaggregation_mode == "decode":
                req.offload_kv_cache(
                    self.req_to_token_pool, self.token_to_kv_pool_allocator
                )

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

            if len(retracted_reqs) == 0:
                # Corner case: only one request left
                raise ValueError(
                    "Failed to retract any request. No space left for only one request."
                )

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
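        """Prepare an empty batch for an IDLE forward pass; all per-batch tensors
        are reset to zero length."""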
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
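        """Prepare the batch for a DECODE forward pass: feed the last output ids
        back as input ids, advance seq_lens by one, and allocate one new KV slot
        per request (EAGLE prepares its own decode batches and returns early)."""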
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
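            # NOTE: seq_lens has already been advanced by one above, so the last
            # existing token of each request is at position seq_lens - 2.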
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
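        """Merge `other` into this batch by concatenating its per-request tensors
        and lists; `out_cache_loc` is reset to None."""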
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(
        self, seq_lens_cpu_cache: Optional[torch.Tensor] = None
    ) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            global_server_args_dict["attention_backend"] == "fa3"
            or (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
                seq_lens_cpu_cache
                if seq_lens_cpu_cache is not None
                else self.seq_lens.cpu()
            )
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            hicache_consumer_index=self.hicache_consumer_index,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sequence length tensor on CPU
    seq_lens_cpu: Optional[torch.Tensor]
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # For cross-encoder models
    token_type_ids: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
    spec_num_draft_tokens: Optional[int] = None
    hicache_consumer_index: int = 0

    # Overlap event
    launch_done: Optional[threading.Event] = None


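# Each Triton program instance handles one request: it copies that request's newly
# allocated KV-cache slot indices from `out_cache_loc` into its row of
# `req_to_token` at positions [pre_len, seq_len).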
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
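    """Return, for each request, the KV-cache location of its last prefix token
    (-1 when the request has no cached prefix)."""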
    if global_server_args_dict["attention_backend"] != "torch_native":
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result