from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
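
A rough sketch of the handoff (illustrative only; the exact call sites live in
scheduler.py, tp_worker.py, and model_runner.py):

    schedule_batch = ScheduleBatch.init_new(reqs, ...)            # CPU-side scheduling state
    model_worker_batch = schedule_batch.get_model_worker_batch()  # GPU-related subset
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)  # low-level GPU tensors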

TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing it in the future.
"""

import copy
import dataclasses
import hashlib
import logging
import threading
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
    "deepep_mode": ServerArgs.deepep_mode,
    "device": ServerArgs.device,
    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "deepep_config": ServerArgs.deepep_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
    "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
    "sampling_backend": ServerArgs.sampling_backend,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "torchao_config": ServerArgs.torchao_config,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
}
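
# Note: the values above are only the ServerArgs defaults; the scheduler is expected
# to overwrite this dict with the live server arguments at startup, so model code can
# read the effective settings (e.g. global_server_args_dict["attention_backend"]).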

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message=None, status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message or "Aborted"
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }
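
# Illustrative examples of the JSON produced by the finish reasons above (the
# token id / length values below are made up): a request stopped by a stop token
# reports
#   {"type": "stop", "matched": 128009}
# while one truncated by max_new_tokens reports
#   {"type": "length", "length": 256}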


class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()


@dataclasses.dataclass
class MultimodalDataItem:
    """
    A single multimodal data item, from a single image, video, audio, or other source.
    """

    modality: Modality

    hash: int = None
    pad_value: int = None

    aspect_ratio_id: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    image_sizes: Tuple[int, int] = None
    image_offsets: Optional[list] = None

    # the real data, pixel_values or audio_features
    # data: Union[List[torch.Tensor], List[np.ndarray]]
    pixel_values: Union[torch.Tensor, np.ndarray] = None
    image_grid_thws: Union[torch.Tensor, np.ndarray] = None
    video_grid_thws: Union[torch.Tensor, np.ndarray] = None

    image_emb_mask: Optional[torch.Tensor] = None
    image_spatial_crop: Optional[torch.Tensor] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # [num_images, (n, w, h)]
    tgt_size: Tuple[int, int] = None

    audio_features: Union[torch.Tensor, np.ndarray] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

    @staticmethod
    def is_empty_list(l):
        if l is None:
            return True
        return len([item for item in flatten_nested_list(l) if item is not None]) == 0

    def set_pad_value(self):
        """
        Set the pad value after first hashing the data
        """

        def data_hash(data) -> int:
            hash_bytes = hashlib.sha256(data).digest()[:8]
            return int.from_bytes(hash_bytes, byteorder="big", signed=False)

        def tensor_hash(tensor_list) -> int:
            """
            hash a tensor or a tensor list
            """
            tensor = tensor_list
            if isinstance(tensor_list, list):
                tensor_list = flatten_nested_list(tensor_list)
                tensor_list = [
                    x.flatten() if isinstance(x, torch.Tensor) else x
                    for x in tensor_list
                ]
                tensor = torch.concat(tensor_list)
            if tensor.is_cuda:
                return gpu_tensor_hash(tensor)
            tensor = tensor.detach().contiguous()

            if tensor.dtype == torch.bfloat16:
                # memoryview() doesn't support PyTorch's BFloat16 dtype
                tensor = tensor.float()

            assert isinstance(tensor, torch.Tensor)
            if tensor.is_cuda:
                # TODO: improve this
                tensor_cpu = tensor.cpu()
            else:
                tensor_cpu = tensor

            mv = memoryview(tensor_cpu.numpy())
            return data_hash(mv.tobytes())

        def hash_feature(f):
            if isinstance(f, list):
                if isinstance(f[0], torch.Tensor):
                    return tensor_hash(f)
                return data_hash(tuple(flatten_nested_list(f)))
            elif isinstance(f, np.ndarray):
                arr = np.ascontiguousarray(f)
                arr_bytes = arr.tobytes()
                return data_hash(arr_bytes)
            elif isinstance(f, torch.Tensor):
                return tensor_hash([f])
            return data_hash(f)

        if self.precomputed_features is not None:
            self.hash = hash_feature(self.precomputed_features)
        elif self.is_audio():
            self.hash = hash_feature(self.audio_features)
        else:
            self.hash = hash_feature(self.pixel_values)

        assert self.hash is not None
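        # The pad value is derived from the content hash so that identical
        # multimodal inputs map to the same placeholder ids, which lets the
        # prefix cache recognize requests that share the same image/audio data.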
        self.pad_value = self.hash % (1 << 30)

    def is_audio(self):
        return (self.modality == Modality.AUDIO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.audio_features)
        )

    def is_image(self):
        return (
            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
        ) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_video(self):
        return (self.modality == Modality.VIDEO) and (
            self.precomputed_features is not None
            or not MultimodalDataItem.is_empty_list(self.pixel_values)
        )

    def is_valid(self) -> bool:
        return self.is_image() or self.is_video() or self.is_audio()

    def validate(self):
        ...
        # TODO

    @staticmethod
    def from_dict(obj: dict):
        kwargs = dict(obj)
        modality = kwargs.pop("modality")
        if isinstance(modality, str):
            modality = Modality[modality]
        ret = MultimodalDataItem(modality=modality, **kwargs)
        ret.validate()
        return ret


@dataclasses.dataclass
class MultimodalInputs:
    """The multimodal data related inputs."""

    # items of data
    mm_items: List[MultimodalDataItem]
    image_pad_len: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Qwen2-VL related
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # image
    im_token_id: Optional[int] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # video
    video_token_id: Optional[int] = None

    # audio
    audio_token_id: Optional[int] = None
    audio_start_id: Optional[int] = None
    audio_end_id: Optional[int] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            mm_items=obj["mm_items"],
        )

        assert isinstance(ret.mm_items, list)
        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

        for item in ret.mm_items:
            item.set_pad_value()

        optional_args = [
            "mrope_positions",
            "mrope_position_delta",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "audio_start_id",
            "audio_end_id",
            "audio_token_id",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def contains_image_inputs(self) -> bool:
        """Whether any item in this multimodal input is an image."""
        return any(item.is_image() for item in self.mm_items)

    def contains_audio_inputs(self) -> bool:
        """Whether any item in this multimodal input is audio."""
        return any(item.is_audio() for item in self.mm_items)

    def contains_mm_input(self) -> bool:
        return any(True for item in self.mm_items if item.is_valid())

    def merge(self, other: MultimodalInputs):
        """
        Merge multimodal inputs when requests are merged.
        """

        # args needed to be merged
        optional_args = [
            "mm_items",
            "image_pad_len",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))

        mrope_positions = self.mrope_positions
        if mrope_positions is not None:
            if other.mrope_positions is None:
                self.mrope_positions = mrope_positions
            else:
                self.mrope_positions = torch.cat(
                    [self.mrope_positions, other.mrope_positions], dim=1
                )

        mrope_position_delta = self.mrope_position_delta
        if mrope_position_delta is not None:
            if other.mrope_position_delta is None:
                self.mrope_position_delta = mrope_position_delta
            else:
                self.mrope_position_delta = torch.cat(
                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
                )

        for key, val in other.__dict__.items():
            if "_id" in key:
                # set token_ids
                if getattr(self, key, None) is None:
                    setattr(self, key, getattr(other, key, None))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
        bootstrap_host: Optional[str] = None,
        bootstrap_port: Optional[int] = None,
        bootstrap_room: Optional[int] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # Whether this request has finished output
        self.finished_output = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = None
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. It is incremented whenever the
        # request is chunked, and decremented whenever a chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Incremental streaming
        self.send_token_offset: int = 0
        self.send_decode_id_offset: int = 0
        # TODO (Byron): send_output_token_logprobs_offset and send_decode_id_offset can be different in disaggregation mode
        # because the decode server does not have the first output token logprobs
        self.send_output_token_logprobs_offset: int = 0

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        # True means the input logprob has been already sent to detokenizer.
        self.input_logprob_sent: bool = False
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            # shape: (bs, 1)
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            # shape: (bs, k)
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0

        # For metrics
        self.time_stats: TimeStats = TimeStats()
        self.has_log_time_stats: bool = False
        self.queue_time_start = None
        self.queue_time_end = None

        # For disaggregation
        self.bootstrap_host: str = bootstrap_host
        self.bootstrap_port: Optional[int] = bootstrap_port
        self.bootstrap_room: Optional[int] = bootstrap_room
        self.disagg_kv_sender: Optional[BaseKVSender] = None

        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        # For the overlap schedule, we delay the KV transfer until `process_batch_result_disagg_prefill`,
        # rather than `process_prefill_chunk` as in the non-overlap case,
        # because the KV data is not ready yet in `process_prefill_chunk`.
        # We use `tmp_end_idx` to store the end index of the KV cache to send.
        self.tmp_end_idx: int = -1
        self.metadata_buffer_index: int = -1

        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        elif enable_hierarchical_cache:
            # in case last_node is evicted during scheduling, we need to update the prefix_indices
            while self.last_node.evicted:
                self.prefix_indices = self.prefix_indices[
                    : -len(self.last_node.host_value)
                ]
                self.last_node = self.last_node.parent

        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if self.grammar is not None:
            if self.grammar.is_terminated():
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None
        self.already_computed = 0

    def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)

    def load_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
        token_indices = req_to_token_pool.req_to_token[
            self.req_pool_idx, : self.seqlen - 1
        ]
        token_to_kv_pool_allocator.load_cpu_copy(self.kv_cache_cpu, token_indices)
        del self.kv_cache_cpu

    def log_time_stats(self):
        # With the overlap schedule, we schedule one decode batch ahead, so this can get called twice.
        if self.has_log_time_stats is True:
            return

        if self.bootstrap_room is not None:
            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        else:
            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}, "
            f"{self.grammar=}, "
            f"{self.sampling_params=})"
        )


# Batch id
bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Events
    launch_done: Optional[threading.Event] = None

    # For chunked prefill in PP
    chunked_req: Optional[Req] = None

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # For multimodal inputs
    multimodal_inputs: Optional[List] = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: Optional[int] = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprobs are not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        chunked_req: Optional[Req] = None,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
            chunked_req=chunked_req,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
        backup_state: bool = False,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        backup_state: bool = False,
    ):
        if self.tree_cache is not None:
            if (
                self.token_to_kv_pool_allocator.available_size()
                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
            ):
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        if backup_state:
            state = self.token_to_kv_pool_allocator.backup_state()

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        if backup_state:
            return out_cache_loc, state
        else:
            return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []
        multimodal_inputs = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            multimodal_inputs.append(req.multimodal_inputs)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
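        # NOTE: with page_size == 1 the allocator hands out individual token
        # slots. For larger pages it also needs the location of each request's
        # last prefix token (see `get_last_loc` at the bottom of this file),
        # presumably so allocation can continue within the last partially
        # filled page before opening new ones.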
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        for mm_input in multimodal_inputs:
            if mm_input is None:
                continue
            for mm_item in mm_input.mm_items:
                pixel_values = getattr(mm_item, "pixel_values", None)
                if isinstance(pixel_values, torch.Tensor):
                    mm_item.pixel_values = pixel_values.to(
                        self.device, non_blocking=True
                    )
        self.multimodal_inputs = multimodal_inputs
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
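        # Mix the current prefill (extend) batch with the running decode batch
        # so both can be executed in a single MIXED forward pass. Each running
        # request contributes exactly one new token (extend_input_len = 1).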
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def new_page_count_next_decode(self):
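        # A request only needs a fresh page for its next decoded token when its
        # current length is an exact multiple of the page size. For example
        # (hypothetical values): with page_size = 4 and request lengths
        # [7, 8, 9], only the length-8 request starts a new page, so this
        # returns 1.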
        page_size = self.token_to_kv_pool_allocator.page_size
        if page_size == 1:
            return len(self.reqs)
        return sum(1 for req in self.reqs if req.seqlen % page_size == 0)

    def check_decode_mem(self, buf_multiplier=1):
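        # The next decode step needs at most one new page per request that
        # crosses a page boundary; multiplying by the page size (and by
        # buf_multiplier when the caller wants extra headroom) gives the number
        # of KV-cache slots that must be free.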
        tokens_required = (
            self.new_page_count_next_decode()
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
        )

        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
            return True

        self.tree_cache.evict(tokens_required)

        return self.token_to_kv_pool_allocator.available_size() >= tokens_required

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
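            # Keep enough free KV-cache slots for `retract_decode_steps` future
            # decode tokens per remaining request; speculative decoding needs
            # extra headroom for its draft and verify tokens.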
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = (
                    len(req.prefix_indices) // server_args.page_size
                ) * server_args.page_size
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

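        # Re-estimate the fraction of the remaining max_new_tokens budget that
        # the kept requests are expected to consume: tokens already decoded plus
        # `retract_decode_steps` more per request, divided by the summed
        # max_new_tokens, clipped to 1.0.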
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

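        # NOTE: in overlap mode the cumulative penalty update is driven by the
        # last token recorded on each Req (or the final prompt token when
        # nothing has been decoded yet) rather than by self.output_ids.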
        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
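            # seq_lens was advanced by one above, so the token being decoded
            # will occupy position seq_lens - 1 and the last token already in
            # the KV cache sits at seq_lens - 2; that slot anchors the paged
            # allocation.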
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

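    # Drop finished requests (and, optionally, an in-flight chunked request)
    # from the batch, keeping the per-request Python lists and the GPU tensors
    # aligned through the same keep_indices.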
    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
            elif chunked_req_to_exclude is None:
                chunked_req_to_exclude = []
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] not in chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        if self.multimodal_inputs is not None:
            self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
        # needs to be called with the not-yet-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)
        if self.multimodal_inputs is not None:
            self.multimodal_inputs.extend(other.multimodal_inputs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        # Create seq_lens_cpu when needed
        if (
            (
                global_server_args_dict["use_mla_backend"]
                and global_server_args_dict["attention_backend"] == "flashinfer"
            )
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
            seq_lens_cpu = None

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=self.multimodal_inputs,
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
            launch_done=self.launch_done,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    seq_lens_cpu: Optional[torch.Tensor]
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None

    # Overlap event
    launch_done: Optional[threading.Event] = None


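# One Triton program per request: copy the request's newly allocated KV cache
# slots (its contiguous chunk of `out_cache_loc`) into
# `req_to_token[req_pool_index, pre_len:seq_len]`, processed in BLOCK_SIZE chunks.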
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


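# For each request, return the KV-cache slot of its last prefix token
# (req_to_token[idx, prefix_len - 1]), or -1 when there is no cached prefix;
# used as the starting point for paged allocation during extend.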
@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )