from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward pass on the GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we may remove it in the future.
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "deepep_config": ServerArgs.deepep_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
    "num_fused_shared_experts": ServerArgs.num_fused_shared_experts,
    "sampling_backend": ServerArgs.sampling_backend,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "torchao_config": ServerArgs.torchao_config,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
}
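# NOTE (added for clarity): the entries above read the class-level defaults of
# the `ServerArgs` dataclass at import time; at server startup they are
# expected to be overwritten with the actually parsed arguments. This is a
# descriptive note on the intent, not an exhaustive contract.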

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, from a single image/video/audio source or other modality.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # The real data: pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data.
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())
        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
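        # Illustrative note (added, hedged): identical data hashes to the same
        # value, so identical images/audio clips share the same pad_value
        # across requests; the modulo keeps pad_value within a 30-bit range so
        # it can be used like a token id when padding input_ids. This describes
        # the intent, not a strict contract.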

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """ """
Mick's avatar
Mick committed
373
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """ """
Mick's avatar
Mick committed
377
378
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args that need to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args are kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked, and decrements whenever a chunk of the request
        # is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has already been sent to the detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0
        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None
        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For the overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap case.
        # This is because the kv cache is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        elif enable_hierarchical_cache:
            # in case last_node is evicted during scheduling, we need to update the prefix_indices
            while self.last_node.evicted:
                self.prefix_indices = self.prefix_indices[
                    : -len(self.last_node.host_value)
                ]
                self.last_node = self.last_node.parent

        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1
        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

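        # Worked example (added, illustrative): with input_len = 10,
        # return_logprob enabled and logprob_start_len = 3, max_prefix_len
        # becomes min(10 - 1, 3) = 3, so at most the first 3 tokens are reused
        # from the prefix cache and input logprobs can still start at token 3.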
        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
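        # Illustrative example (added, not executed): with 5 prompt tokens and
        # no output yet, read_offset = 5 and surr_offset = max(5 - 5, 0) = 0,
        # so this returns (all_ids[0:], 5): the ids to decode plus the number
        # of leading "surrounding" tokens whose text was already emitted.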
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]
        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0
    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu
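    # Note (added, hedged): the two helpers above snapshot this request's
    # KV-cache entries into a CPU copy and restore them later via the
    # allocator's get_cpu_copy/load_cpu_copy, presumably so the GPU slots can
    # be reclaimed while the request is temporarily swapped out.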

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""
    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )
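
    # A minimal usage sketch (added, illustrative; the surrounding objects are
    # hypothetical placeholders, not a prescribed API flow):
    #   batch = ScheduleBatch.init_new(
    #       reqs=[req],
    #       req_to_token_pool=req_to_token_pool,
    #       token_to_kv_pool_allocator=token_to_kv_pool_allocator,
    #       tree_cache=tree_cache,
    #       model_config=model_config,
    #       enable_overlap=False,
    #       spec_algorithm=spec_algorithm,
    #       enable_custom_logit_processor=False,
    #   )
    #   batch.prepare_for_extend()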

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc
    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )
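        # Note (added, descriptive): the availability check above budgets
        # extend_num_tokens plus one extra page per sequence
        # (len(seq_lens) * page_size) as slack for page-aligned allocation;
        # this describes the heuristic, not a strict bound.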
        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc
    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc
    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]
        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We need req.extend_input_len - req.extend_logprob_start_len tokens in total;
                # logprob_token_ids only covers the input logprobs, so pad the rest with 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
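        """Mix the running decode batch into this prefill batch so both are executed in a single MIXED forward pass."""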
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
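        """Return how many requests will need a new KV-cache page for their next decode token."""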
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
        return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)

    def check_decode_mem(self, buf_multiplier=1):
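        """Return True if there is enough free KV-cache space for the next decode step, evicting from the tree cache if needed."""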
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
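            # Tokens needed to keep decoding `num_reqs` requests for another
            # `retract_decode_steps` steps, plus headroom for speculative draft tokens.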
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
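        """Prepare an empty IDLE batch that runs a forward pass with no tokens."""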
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
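        """Prepare the batch for a decode step: feed the last sampled tokens as input, advance seq_lens, and allocate KV cache slots for the new tokens."""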
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
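        """Filter the batch in place, keeping only the requests at `keep_indices` (by default, all unfinished requests except `chunked_req_to_exclude`)."""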
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
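        """Merge another ScheduleBatch into this one by concatenating per-request tensors, lists, and flags."""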
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
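        """Pack the fields needed for the model forward pass into a ModelWorkerBatch."""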
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None

    # Overlap event
    launch_done: Optional[threading.Event] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
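    # One program instance per request: copy the request's newly allocated KV cache
    # locations (a contiguous slice of `out_cache_loc`) into
    # req_to_token[req_pool_index, pre_len:seq_len].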
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
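    """Return the KV-cache index of the last prefix token for each request, or -1 if the prefix is empty."""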
    if global_server_args_dict["attention_backend"] != "torch_native":
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
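    # Each program instance handles BLOCK_SIZE requests: gather
    # req_to_token[req_pool_idx, prefix_len - 1], or -1 when prefix_len == 0.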
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
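    # Triton-backed version of get_last_loc_torch (one program per BLOCK_SIZE requests).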
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result