from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
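
A rough sketch of the typical flow (illustrative only; the exact call sites
live in scheduler.py and tp_worker.py):

    batch = ScheduleBatch.init_new(reqs, ...)  # the scheduler builds a CPU-side batch
    batch.prepare_for_extend()                 # prefill: allocate KV slots, build GPU tensors
    ...                                        # the TP worker turns it into a ModelWorkerBatch
    batch.prepare_for_decode()                 # decode: one new token slot per request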
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "enable_flashmla": ServerArgs.enable_flashmla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    # [num_of_images, t, h, w]
    image_grid_thws: torch.Tensor = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # Qwen2-VL video related
    video_token_id: Optional[int] = None
    video_grid_thws: List[Tuple[int, int, int]] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # deepseek vl2 related
    images_emb_mask: Optional[List[torch.Tensor]] = None
    image_spatial_crop: Optional[List[torch.Tensor]] = None

    # The id of the single-image placeholder token
    im_token_id: Optional[torch.Tensor] = None

    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None
    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "images_emb_mask",
            "image_spatial_crop",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        # validate
        assert (
            isinstance(ret.pixel_values, torch.Tensor)
            or isinstance(ret.pixel_values, np.ndarray)
            or isinstance(ret.pixel_values, list)
        )

        return ret

    def merge(self, other: ImageInputs):
        """
        Merge image inputs when requests are merged.
        """
        if isinstance(self.pixel_values, list):
            # in some rare cases, pixel values are list of patches with different shapes
            # e.g. minicpm
            self.pixel_values += other.pixel_values
        else:
            assert (
                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # args would be stacked along first dim
        # usually these are already tensors
        stack_args = [
            # TODO: merge with image_grid_thws, basically the same thing
            "tgt_sizes",
            "image_spatial_crop",
        ]
        for arg in stack_args:
            if getattr(self, arg, None) is None:
                setattr(self, arg, getattr(other, arg, None))
            elif getattr(other, arg, None) is not None:
                # self and other both not None
                setattr(
                    self,
                    arg,
                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
                )

        if self.image_grid_thws is None:
            self.image_grid_thws = other.image_grid_thws
        elif other.image_grid_thws is not None:
            self.image_grid_thws = torch.concat(
                [self.image_grid_thws, other.image_grid_thws]
            )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
        # Args that need to be merged
        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "images_emb_mask",
        ]
        for arg in optional_args:
            self_arg = getattr(self, arg, None)
            if self_arg is not None:
                setattr(self, arg, self_arg + getattr(other, arg))
        # other args would be kept intact


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true.
        # Note: we should never set finished_reason in the middle; the req would get
        # filtered out and never respond.
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
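        # Example (illustrative): for a prompt of 10 tokens and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and
        # surr_offset = 5, so the last 5 prompt tokens are re-decoded as
        # surrounding context for the newly generated tokens.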
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether the request is chunked. This counter increments whenever
        # the request is chunked and decrements whenever a chunk of it is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = "0.0.0.0"
        self.bootstrap_room: Optional[int] = None
        self.disagg_kv_sender: Optional[KVSender] = None

        # used for warmup because we don't have a pair yet when init
        self.skip_kv_transfer: bool = False
        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
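        # Illustrative example: if fill_ids has 100 tokens and 60 of them hit the
        # prefix cache, prefix_indices holds the 60 cached KV indices and
        # extend_input_len == 40, so only 40 tokens need a prefill forward pass.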

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int64
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # This is left empty when logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder architectures
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=return_logprob,
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
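        # The `len(seq_lens) * page_size` term below is a conservative headroom for
        # page-aligned allocation: in the worst case every sequence has to open one
        # new, partially filled page.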
        if (
            self.token_to_kv_pool_allocator.available_size()
            < extend_num_tokens
            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    extend_num_tokens
                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            error_msg = (
                f"Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        if (
            self.token_to_kv_pool_allocator.available_size()
            < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
        ):
            if self.tree_cache is not None:
                self.tree_cache.evict(
                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
                )
        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)

        if out_cache_loc is None:
            error_msg = (
                f"Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        # Allocate req slots
        bs = len(self.reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Init tensors
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]

        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        input_ids_tensor = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        # Copy prefix and do some basic check
        input_embeds = []
        extend_input_logprob_token_ids = []

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            if req.is_retracted:
                req.already_computed = 0
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size == 1:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
        else:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )

        # Set fields
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        if global_server_args_dict["attention_backend"] != "torch_native":
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[pt : pt + extend_lens[i]],
                )
                pt += extend_lens[i]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

        self.tree_cache.evict(bs)

        if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

        return False

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
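        # Illustrative example (hypothetical numbers): with 8 remaining requests,
        # retract_decode_steps = 20, speculative_eagle_topk = 4,
        # speculative_num_steps = 3 and speculative_num_draft_tokens = 8, the
        # speculative-decoding headroom is 8 * 4 * 3 + 8 * 8 = 160 tokens, so
        # get_required_tokens(8) returns 8 * 20 + 160 = 320 tokens.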

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
1152
                self.req_to_token_pool.free(req.req_pool_idx)
1153
1154
1155
1156
1157
1158
1159

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
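        # Illustrative example (hypothetical numbers): if the remaining 4 requests
        # have decoded 100 tokens in total, their max_new_tokens sum to 400, and
        # retract_decode_steps is 20, the new ratio is (100 + 20 * 4) / 400 = 0.45.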

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )
Ke Bao's avatar
Ke Bao committed
1196

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # If spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running the draft model.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
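                # With overlap scheduling, the previous step's sampled tokens may not
                # have been copied into self.output_ids yet, so gather each request's
                # last known token from the CPU-side lists instead.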
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        # Update fields
        self.input_ids = self.output_ids
        self.output_ids = None

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens.clone()

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

        # Allocate memory
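        # With page_size == 1, any free slot works. With larger pages, the allocator
        # needs each request's previous token slot (seq_lens was already incremented,
        # hence seq_lens - 2) to decide whether the new token still fits in the
        # current page.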
        if self.token_to_kv_pool_allocator.page_size == 1:
            self.out_cache_loc = self.alloc_token_slots(bs)
        else:
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
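        # Keep only the requests selected by `keep_indices` (by default, all requests
        # that are unfinished and not the currently chunked request); every
        # per-request field below is re-sliced to match.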
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
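        # Merge `other` into this batch. Per-request lists and tensors are concatenated
        # in the same order, so indices in the merged batch stay aligned across fields.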
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() reads Batch.reqs while preparing each penalizer, so it
        # has to see the pre-merge Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
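            # The MLA backends (FlashInfer MLA / FlashMLA) need the decode seq_lens on
            # the host, so only materialize a CPU copy for them.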
            if (
                global_server_args_dict["enable_flashinfer_mla"]
                or global_server_args_dict["enable_flashmla"]
            ):
                decode_seq_lens = self.seq_lens.cpu()
            else:
                decode_seq_lens = None
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            decode_seq_lens = None
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            decode_seq_lens=decode_seq_lens,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only include the fields that will be used by process_batch_result.
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For decode
    decode_seq_lens: Optional[torch.Tensor]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
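    # Each program instance handles one request: it copies that request's newly
    # allocated KV-cache slot indices from out_cache_loc into its row of
    # req_to_token, filling positions [pre_len, seq_len).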
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
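    # For each request, return the KV-cache slot of its last prefix token
    # (req_to_token[req, prefix_len - 1]), or -1 when the request has no prefix.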
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )