from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.

TODO(lmzheng): ModelWorkerBatch seems a bit redundant, and we may remove it in the future.
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "deepep_config": ServerArgs.deepep_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "sampling_backend": ServerArgs.sampling_backend,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "torchao_config": ServerArgs.torchao_config,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "ep_num_redundant_experts": ServerArgs.ep_num_redundant_experts,
}
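
# Example read (illustrative): the entries above hold plain ServerArgs defaults
# until the scheduler overwrites them with the live server arguments at startup,
# e.g.:
#   if global_server_args_dict["attention_backend"] != "torch_native": ...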

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, e.g. from one image, video, or audio clip.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
        self.pad_value = self.hash % (1 << 30)
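        # Illustrative note: pad_value is the low 30 bits of the content hash,
        # giving a stable per-item id; it is used when padding input_ids so
        # prefix caching can distinguish different multimodal items (an
        # assumption about downstream use, not shown in this file).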

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # QWen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """Whether any of the multimodal items is an image."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Whether any of the multimodal items is audio."""
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when two requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
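        # Illustrative example (hypothetical numbers): with a 12-token prompt
        # and 3 decoded tokens, the first call to init_incremental_detokenize()
        # sets read_offset = 12 and surr_offset = max(12 - 5, 0) = 7, and the
        # method returns (all_ids[7:], 5): the 5 surrounding prompt tokens are
        # re-decoded together with the new output tokens for stable text.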
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. The counter increases every time the
        # request is chunked and decreases every time a chunk of it is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
        # This is because kv is not ready in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the kv cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        elif enable_hierarchical_cache:
            # in case last_node is evicted during scheduling, we need to update the prefix_indices
            while self.last_node.evicted:
                self.prefix_indices = self.prefix_indices[
                    : -len(self.last_node.host_value)
                ]
                self.last_node = self.last_node.parent

        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # If overlap schedule, we schedule one decode batch ahead so this gets called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # This is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
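        # NOTE (editor): with page_size == 1 every token gets its own KV slot,
        # so a flat allocation of `extend_num_tokens` slots is enough. With
        # larger pages the allocator also needs each request's current last
        # slot (last_loc), presumably so extended tokens keep filling a
        # partially used page before new pages are allocated.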
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)
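        # NOTE (editor, illustrative): with the non-overlap scheduler
        # (delta == -1), a running request with 10 prompt and 5 output tokens
        # is mixed in with prefix_len == 14 and extend_len == 1, i.e. only its
        # newest token goes through the extend (prefill) kernel.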

    def new_page_count_next_decode(self):
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
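        # NOTE (editor, illustrative): with page_size == 4, a request at
        # seqlen == 8 has exactly filled two pages, so its next decode token
        # needs a fresh page; at seqlen == 9, 10, or 11 the current page still
        # has free slots.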

    def check_decode_mem(self, buf_multiplier=1):
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )
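        # NOTE (editor, illustrative): decode adds at most one page per request
        # per step, so the requirement is counted in whole pages; e.g. three
        # requests crossing a page boundary with page_size == 16 and
        # buf_multiplier == 1 require 3 * 16 = 48 free slots.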

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
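        # NOTE (editor, illustrative): with retract_decode_steps == 20, 8
        # remaining requests, eagle_topk == 4, num_steps == 3, and
        # num_draft_tokens == 4, this reserves 8 * 20 + 8 * 4 * 3 + 8 * 4 = 288
        # tokens of headroom.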

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
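        # NOTE (editor, illustrative): the ratio estimates how much of the total
        # max_new_tokens budget will be consumed; e.g. 300 already-decoded tokens
        # plus a 20-step reserve for 5 requests against a 1000-token budget
        # gives (300 + 100) / 1000 = 0.4.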

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
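            # NOTE (editor): self.seq_lens was already advanced above, so the
            # last existing token of each request sits at index seq_lens - 2
            # and the newly decoded token will occupy index seq_lens - 1.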
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]
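            # NOTE (editor, illustrative): e.g. if requests 0 and 2 have
            # finished, keep_indices becomes [1, 3, ...] and every per-request
            # field below is re-indexed with it.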

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with the pre-merge Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None

    # Overlap event
    launch_done: Optional[threading.Event] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
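

# NOTE (editor): the sketch below is an added, purely illustrative pure-Python
# reference of what `write_req_to_token_pool_triton` computes (it mirrors the
# torch_native fallback above). It is not part of the original module and is
# not called anywhere.
def _write_req_to_token_pool_reference(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    pre_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    extend_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
) -> None:
    # For request i, copy its newly allocated cache slots into
    # req_to_token[pool_idx, pre_len:seq_len]. Slots are laid out contiguously
    # in out_cache_loc in batch order, hence the running offset `pt`
    # (the kernel derives the same offset as a cumulative sum of extend_lens).
    pt = 0
    for i in range(len(req_pool_indices)):
        pre_len = int(pre_lens[i])
        seq_len = int(seq_lens[i])
        extend_len = int(extend_lens[i])
        req_to_token[int(req_pool_indices[i]), pre_len:seq_len] = out_cache_loc[
            pt : pt + extend_len
        ]
        pt += extend_len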


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )
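

# NOTE (editor, illustrative): get_last_loc returns, for each request, the KV
# cache slot of its last prefix token, or -1 when the request has no cached
# prefix; e.g. prefix_lens [0, 3] -> [-1, req_to_token[pool_idx_1, 2]].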