from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
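
A rough lifecycle sketch of the flow above (hedged: `get_model_worker_batch` and
`ForwardBatch.init_new` live outside this file and are assumptions here, not a
guaranteed API):

    batch = ScheduleBatch.init_new(reqs, ...)   # CPU-side scheduling state
    batch.prepare_for_extend()                  # allocate KV slots, build tensors
    model_worker_batch = batch.get_model_worker_batch()
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)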

TODO(lmzheng): ModelWorkerBatch seems a bit redundant, and we may consider removing it in the future.
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "deepep_config": ServerArgs.deepep_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
    "sampling_backend": ServerArgs.sampling_backend,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "torchao_config": ServerArgs.torchao_config,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, from a single image, video, audio, or other source.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    # kimi-vl related
    image_grid_hws: Optional[List[torch.Tensor]] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None
    audio_offsets: Optional[List[Tuple[int, int]]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
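        # Note (assumption): this content-derived pad value is what replaces the
        # multimodal placeholder token ids in the prompt, so two different
        # images/audios yield different token patterns and the prefix cache can
        # tell them apart.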

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret
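
    # Usage sketch (hedged; in practice items are produced by the multimodal
    # processors, and only the fields defined above are meaningful):
    #   item = MultimodalDataItem.from_dict(
    #       {"modality": "IMAGE", "pixel_values": pixel_values}
    #   )
    #   item.set_pad_value()  # fills `hash` and `pad_value` from the data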


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Qwen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """Whether the multimodal inputs contain any image items."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Whether the multimodal inputs contain any audio items."""
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true.
        # Note: We should never set finished_reason in the middle; the request would get filtered out and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
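        # Worked example (assuming the default INIT_INCREMENTAL_DETOKENIZATION_OFFSET
        # of 5): for an unpadded prompt of 12 tokens, the first call to
        # init_incremental_detokenize() sets read_offset = 12 and surr_offset = 7,
        # so read_ids covers the last 5 prompt tokens plus all output tokens so far.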
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter increments whenever the
        # request is chunked, and decrements whenever a chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For the overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap case.
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finished condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        elif enable_hierarchical_cache:
            # in case last_node is evicted during scheduling, we need to update the prefix_indices
            while self.last_node.evicted:
                self.prefix_indices = self.prefix_indices[
                    : -len(self.last_node.host_value)
                ]
                self.last_node = self.last_node.parent

        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )
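            # Hedged note: each output token usually decodes to at least one
            # character, so the last `stop_str_max_len + 1` tokens are typically
            # enough to contain any stop string ending at the newest token;
            # earlier occurrences are caught by the `self.decoded_text` check below.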

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )
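
    # Usage sketch (hedged; the scheduler wires up the real pools and caches):
    #   batch = ScheduleBatch.init_new(
    #       reqs, req_to_token_pool, token_to_kv_pool_allocator, tree_cache,
    #       model_config, enable_overlap=False, spec_algorithm=spec_algorithm,
    #       enable_custom_logit_processor=False,
    #   )
    #   batch.prepare_for_extend()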

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
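        # Note (assumption): the extra `len(seq_lens) * page_size` term below is
        # head-room so that, in the worst case, every sequence in the batch can
        # open a new partially-filled page during this allocation.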
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert (
            len(self.out_cache_loc) == self.extend_num_tokens
        ), f"Expected {self.extend_num_tokens}, got {len(self.out_cache_loc)}"

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic checks
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We need req.extend_input_len - req.extend_logprob_start_len tokens for the
                # input logprobs; logprob_token_ids may be shorter than that, so pad the rest with 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if support_triton(global_server_args_dict.get("attention_backend")):
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1
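        # Each running (decode) request contributes exactly one new token to the
        # mixed extend batch.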

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        # In the decoding phase, the length of a request's KV cache should be
        # the total length of the request minus 1
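        # A new page is needed only when that length lands exactly on a page
        # boundary. E.g., with page_size=4 and req.seqlen=9, the 8 cached tokens
        # fill exactly 2 pages, so the next decode token must open a new page.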
        return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)

    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
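        # E.g., with page_size=16, buf_multiplier=1, and 3 requests about to cross
        # a page boundary, the next decode step needs 3 * 16 = 48 free token slots.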

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
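        # E.g., without speculative decoding, retraction continues until at least
        # len(sorted_indices) * retract_decode_steps free token slots are available,
        # giving every remaining request headroom for that many decode steps.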

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
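            # self.seq_lens was already advanced by 1 above, so the last token
            # currently present in the KV cache sits at position seq_lens - 2.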
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None

    # Overlap event
    launch_done: Optional[threading.Event] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)
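    # One program instance handles one request: pid indexes the batch, and the
    # loops below copy that request's newly allocated KV slots into its row of
    # req_to_token, starting at offset pre_len.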

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if global_server_args_dict["attention_backend"] != "torch_native":
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
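    # Torch path used with the torch_native attention backend: read the KV-cache
    # location of each request's last prefix token; requests with an empty prefix
    # get -1.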
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result