"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "3a04aa4be7121543222f0626db066cabccee2a83"
schedule_batch.py 56.3 KB
Newer Older
1
2
from __future__ import annotations

3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16
17
18
19
20
21
22
23
24
25
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
Lianmin Zheng's avatar
Lianmin Zheng committed
31

32
import copy
33
import dataclasses
Ying Sheng's avatar
Ying Sheng committed
34
import logging
35
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
36

37
import numpy as np
Lianmin Zheng's avatar
Lianmin Zheng committed
38
import torch
39
40
import triton
import triton.language as tl
41

Liangsheng Yin's avatar
Liangsheng Yin committed
42
from sglang.global_config import global_config
43
from sglang.srt.configs.model_config import ModelConfig
44
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
Byron Hsu's avatar
Byron Hsu committed
45
46
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
47
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
48
from sglang.srt.mem_cache.chunk_cache import ChunkCache
49
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
Lianmin Zheng's avatar
Lianmin Zheng committed
50
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
51
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
52
from sglang.srt.sampling.sampling_params import SamplingParams
53
from sglang.srt.server_args import ServerArgs
54
from sglang.srt.utils import get_compiler_backend
Liangsheng Yin's avatar
Liangsheng Yin committed
55

56
if TYPE_CHECKING:
57
58
59
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

Liangsheng Yin's avatar
Liangsheng Yin committed
60
# How many tokens before the end of the (unpadded) prompt the first
# incremental-detokenization window starts; Req.init_incremental_detokenize
# uses it to place `surr_offset` on the first call.
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access.
# Each entry is keyed by, and seeded from, the class-level value of the
# ServerArgs attribute with the same name (insertion order is preserved).
global_server_args_dict = {
    arg_name: getattr(ServerArgs, arg_name)
    for arg_name in (
        "attention_backend",
        "sampling_backend",
        "triton_attention_reduce_in_fp32",
        "disable_mla",
        "torchao_config",
        "enable_nan_detection",
        "enable_dp_attention",
        "enable_ep_moe",
        "enable_deepep_moe",
        "device",
        "speculative_accept_threshold_single",
        "speculative_accept_threshold_acc",
        "enable_flashinfer_mla",
        "enable_flashmla",
        "disable_radix_cache",
        "flashinfer_mla_disable_ragged",
    )
}

logger = logging.getLogger(__name__)


85
86
87
class BaseFinishReason:
    """Common base for the reasons a request stops generating.

    Subclasses hold the details of one concrete reason and serialize them
    through ``to_json``.
    """

    def __init__(self, is_error: bool = False):
        # Marks reasons that represent a failure (FINISH_ABORT passes True).
        self.is_error = is_error

    def to_json(self):
        """Serialize the finish reason; concrete subclasses must override this."""
        raise NotImplementedError
91
92
93


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    """Finish reason used when generation hit a stop/EOS token id."""

    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        # The token id (or ids) that triggered the stop.
        self.matched = matched

    def to_json(self):
        # "stop" mirrors the finish_reason value returned by the OpenAI API.
        payload = {"type": "stop"}
        payload["matched"] = self.matched
        return payload
103
104


105
106
class FINISH_MATCHED_STR(BaseFinishReason):
    """Finish reason used when the generated text contained a stop string."""

    def __init__(self, matched: str):
        super().__init__()
        # The stop string that was found.
        self.matched = matched

    def to_json(self):
        # "stop" mirrors the finish_reason value returned by the OpenAI API.
        return dict(type="stop", matched=self.matched)
115
116


117
118
class FINISH_LENGTH(BaseFinishReason):
    """Finish reason used when the request reached its token budget."""

    def __init__(self, length: int):
        super().__init__()
        # The length limit that was reached (max_new_tokens in Req.check_finished).
        self.length = length

    def to_json(self):
        # "length" mirrors the finish_reason value returned by the OpenAI API.
        return dict(type="length", length=self.length)
127
128
129


class FINISH_ABORT(BaseFinishReason):
    """Finish reason for a request that was aborted; always flagged as an error."""

    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        # Human-readable explanation plus optional status code / error category.
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return dict(
            type="abort",
            message=self.message,
            status_code=self.status_code,
            err_type=self.err_type,
        )
143

Lianmin Zheng's avatar
Lianmin Zheng committed
144

145
@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: Optional[List[Tuple[int, int, int]]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # Qwen2-VL video related
    video_token_id: Optional[int] = None
    video_grid_thws: Optional[List[Tuple[int, int, int]]] = None
    second_per_grid_ts: Optional[List[torch.Tensor]] = None

    # deepseek vl2 related
    image_seq_mask: Optional[List[torch.Tensor]] = None
    image_spatial_crop: Optional[List[torch.Tensor]] = None

    # The id of the single-image placeholder token
    im_token_id: Optional[torch.Tensor] = None

    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None
    tgt_sizes: Optional[list] = None

    # denotes the number of valid image tokens in each image
    images_emb_mask: Optional[torch.BoolTensor] = None

    @staticmethod
    def from_dict(obj: dict):
        """Build an ImageInputs from a dict.

        ``obj`` must contain "pixel_values" and "image_hashes"; any of the
        optional keys listed below that are present are copied onto the
        result. ``pad_values`` is derived from ``image_hashes``.
        """
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "image_seq_mask",
            "image_spatial_crop",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
            "images_emb_mask",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        """
        merge image inputs when requests are being merged

        Appends ``other``'s images after this object's images, in place.
        ``pixel_values`` is concatenated along the first axis with
        ``np.concatenate``; list-valued per-image fields are concatenated
        when this object already has them.
        """
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        # Build a new list instead of using `+=`: `from_dict` stores the caller's
        # list object directly, so an in-place extend would mutate that list too.
        self.image_hashes = self.image_hashes + other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "image_seq_mask",
            "image_spatial_crop",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))

Liangsheng Yin's avatar
Liangsheng Yin committed
252

Lianmin Zheng's avatar
Lianmin Zheng committed
253
class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: Optional[List[int]] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        # NOTE: annotated as a tuple, but it is concatenated with the list
        # `self.output_ids` (see init_next_round_input), which requires a list.
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            # Copy so the shared SamplingParams object is not modified, then
            # expose this request to custom params under the "__req__" key.
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None
        self.last_node_global = None

        # Whether or not if it is chunked. It increments whenever
        # it is chunked, and decrement whenever chunked request is
        # processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob
        self.temp_scaled_logprobs = False
        self.top_p_normalized_logprobs = False

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            # Output logprob containers are not allocated when logprobs are off.
            self.output_token_logprobs_val = None
            self.output_token_logprobs_idx = None
            self.output_top_logprobs_val = None
            self.output_top_logprobs_idx = None
            self.output_token_ids_logprobs_val = None
            self.output_token_ids_logprobs_idx = None
        self.hidden_states: List[List[float]] = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

        # For disaggregation
        self.bootstrap_host: str = "0.0.0.0"
        self.bootstrap_room: Optional[int] = None
        self.disagg_kv_sender: Optional[KVSender] = None

        # used for warmup because we don't have a pair yet when init
        self.skip_kv_transfer: bool = False
        # the start index of the sent kv cache
        # We want to send it chunk by chunk for chunked prefill.
        # After every chunk forward, we do the following:
        # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
        # start_send_idx = len(req.fill_ids)
        self.start_send_idx: int = 0

        self.metadata_buffer_index: int = -1
        # The first output_id transferred from prefill instance.
        self.transferred_output_id: Optional[int] = None

    @property
    def seqlen(self):
        """Total number of tokens so far: prompt tokens plus generated tokens."""
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        """Attach `image_inputs`, merging into any image inputs already present."""
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finished_reason is not None

    def init_next_round_input(
        self,
        tree_cache: Optional[BasePrefixCache] = None,
        enable_hierarchical_cache=False,
    ):
        """Prepare fill_ids / prefix info for the next prefill round.

        When `tree_cache` is given, the longest cached prefix is looked up and
        stored in `prefix_indices` / `last_node` (and `last_node_global` for the
        hierarchical cache). `extend_input_len` becomes the number of tokens in
        `fill_ids` not covered by the prefix.
        """
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            if enable_hierarchical_cache:
                self.prefix_indices, self.last_node, self.last_node_global = (
                    tree_cache.match_prefix(
                        key=self.adjust_max_prefix_ids(), include_evicted=True
                    )
                )
            else:
                self.prefix_indices, self.last_node = tree_cache.match_prefix(
                    rid=self.rid, key=self.adjust_max_prefix_ids()
                )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        """Return the part of fill_ids that may be matched against the prefix cache.

        At least the last token is always excluded, and when logprobs are
        requested the prefix also stops at `logprob_start_len`.
        """
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        # Leaving the last token out also guarantees at least one token to compute
        # logits when max_new_tokens > 0 (which previously had a separate, redundant check).
        max_prefix_len = input_len - 1

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        """Return (ids starting at surr_offset, read_offset - surr_offset).

        On the first call the offsets are initialized from the unpadded prompt
        length, with surr_offset placed up to
        INIT_INCREMENTAL_DETOKENIZATION_OFFSET tokens earlier.
        """
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def check_finished(self):
        """Set `finished_reason` if the request should stop; no-op if already finished.

        Checked in order: pending abort, max_new_tokens reached, stop/EOS token
        (unless ignore_eos), then stop strings.
        """
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if not self.output_ids:
            # Nothing generated yet: there is no last token or text to match.
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            # Decoding the recent tail needs a tokenizer; without one, only the
            # already decoded text can be searched.
            tail_str = ""
            if self.tokenizer is not None:
                tail_str = self.tokenizer.decode(
                    self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
                )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        """Clear prefix/prefill progress so the request can be scheduled again."""
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        # Drop partially accumulated input logprobs; prefill restarts from scratch.
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.temp_input_token_ids_logprobs_val = None
        self.temp_input_token_ids_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
548
549


550
551
552
# Module-level integer counter starting at 0. It is not used in the visible part
# of this file; presumably it is incremented elsewhere to tag batches with an id
# (TODO confirm against its users).
bid = 0


553
@dataclasses.dataclass
Byron Hsu's avatar
Byron Hsu committed
554
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
555
    """Store all information of a batch on the scheduler."""
556

557
    # Request, memory pool, and cache
558
    reqs: List[Req]
559
    req_to_token_pool: ReqToTokenPool = None
560
    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
561
    tree_cache: BasePrefixCache = None
562

563
    # Batch configs
564
    model_config: ModelConfig = None
Liangsheng Yin's avatar
Liangsheng Yin committed
565
    forward_mode: ForwardMode = None
566
    enable_overlap: bool = False
Lianmin Zheng's avatar
Lianmin Zheng committed
567
568
569
570
    # Tell whether the current running batch is full so that we can skip
    # the check of whether to prefill new requests.
    # This is an optimization to reduce the overhead of the prefill check.
    batch_is_full: bool = False
571
572

    # Sampling info
573
    sampling_info: SamplingBatchInfo = None
574
    next_batch_sampling_info: SamplingBatchInfo = None
Liangsheng Yin's avatar
Liangsheng Yin committed
575

576
    # Batched arguments to model runner
Lianmin Zheng's avatar
Lianmin Zheng committed
577
    input_ids: torch.Tensor = None  # shape: [b], int64
578
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
Lianmin Zheng's avatar
Lianmin Zheng committed
579
    req_pool_indices: torch.Tensor = None  # shape: [b], int64
580
    seq_lens: torch.Tensor = None  # shape: [b], int64
581
    # The output locations of the KV cache
Lianmin Zheng's avatar
Lianmin Zheng committed
582
583
    out_cache_loc: torch.Tensor = None  # shape: [b], int64
    output_ids: torch.Tensor = None  # shape: [b], int64
584

585
586
587
    # The sum of all sequence lengths
    seq_lens_sum: int = None

Ke Bao's avatar
Ke Bao committed
588
589
    # For DP attention
    global_num_tokens: Optional[List[int]] = None
590
    global_num_tokens_for_logprob: Optional[List[int]] = None
591
    can_run_dp_cuda_graph: bool = False
Ke Bao's avatar
Ke Bao committed
592

593
    # For processing logprobs
594
    return_logprob: bool = False
595
    top_logprobs_nums: Optional[List[int]] = None
596
    token_ids_logprobs: Optional[List[List[int]]] = None
597

Lianmin Zheng's avatar
Lianmin Zheng committed
598
599
600
601
    # For logits and logprob post processing
    temp_scaled_logprobs: bool = False
    top_p_normalized_logprobs: bool = False

602
603
604
605
    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
606
    decoding_reqs: List[Req] = None
Lianmin Zheng's avatar
Lianmin Zheng committed
607
    extend_logprob_start_lens: List[int] = None
608
609
    # It is None if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None
610

Lianmin Zheng's avatar
Lianmin Zheng committed
611
    # For encoder-decoder architectures
612
613
614
615
616
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

617
618
619
    # Stream
    has_stream: bool = False

620
621
    # Has grammar
    has_grammar: bool = False
622

623
    # Device
624
625
    device: str = "cuda"

626
    # Speculative decoding
627
    spec_algorithm: SpeculativeAlgorithm = None
628
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
629

630
631
632
    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

633
634
635
    # Whether to return hidden states
    return_hidden_states: bool = False

636
    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        """Create a batch over ``reqs`` with the scheduler's pools and configs.

        The per-batch flags ``return_logprob``, ``has_stream``, ``has_grammar``
        and ``return_hidden_states`` are True as soon as any request has the
        matching attribute set (truthy). The device is taken from
        ``req_to_token_pool.device``.
        """

        def _any_req(attr: str) -> bool:
            # True when at least one request has a truthy ``attr``.
            return any(getattr(req, attr) for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            device=req_to_token_pool.device,
            return_logprob=_any_req("return_logprob"),
            has_stream=_any_req("stream"),
            has_grammar=_any_req("grammar"),
            return_hidden_states=_any_req("return_hidden_states"),
        )

666
    def batch_size(self):
        """Return how many requests the batch currently holds."""
        requests = self.reqs
        return len(requests)
668

Lianmin Zheng's avatar
Lianmin Zheng committed
669
670
671
    def is_empty(self):
        """Return True when the batch holds no requests."""
        return not len(self.reqs)

672
    def alloc_req_slots(self, num_reqs: int):
        """Reserve ``num_reqs`` slots in ``req_to_token_pool``.

        Returns:
            Whatever ``req_to_token_pool.alloc`` returns (the slot indices).

        Raises:
            RuntimeError: if the pool returns None, i.e. it has no room left.
        """
        slots = self.req_to_token_pool.alloc(num_reqs)
        if slots is not None:
            return slots

        # NOTE: the `=` specifiers embed the expression text in the message, so
        # the expressions below must stay spelled exactly like this.
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{self.req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )

    def alloc_token_slots(self, num_tokens: int):
        """Allocate ``num_tokens`` KV-cache slots from the token allocator.

        Callers in this class use it when the allocator's ``page_size`` is 1.
        If the allocator reports fewer than ``num_tokens`` free slots, the tree
        cache (when one is configured) is first asked to evict ``num_tokens``.

        Returns:
            The slot indices returned by ``token_to_kv_pool_allocator.alloc``.

        Raises:
            RuntimeError: if the allocator still returns None after eviction.
        """
        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens)

        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

        if out_cache_loc is None:
            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
            # tree_cache is optional (see the guards in this method). Without
            # one nothing is evictable; do not dereference None while building
            # the error, or an AttributeError would mask the real OOM.
            evictable_size = (
                self.tree_cache.evictable_size() if self.tree_cache is not None else 0
            )
            available_size = self.token_to_kv_pool_allocator.available_size()
            error_msg = (
                f"{phase_str} out of memory. Try to lower your batch size.\n"
                f"Try to allocate {num_tokens} tokens.\n"
                f"Avaliable tokens: {available_size + evictable_size}\n"
            )
            logger.error(error_msg)
            if self.tree_cache is not None:
                self.tree_cache.pretty_print()
            raise RuntimeError(error_msg)

        return out_cache_loc

    def alloc_paged_token_slots_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate KV-cache slots for an extend (prefill) batch with a paged pool.

        Eviction from the tree cache (when configured) is triggered if the
        allocator has fewer than ``extend_num_tokens + len(seq_lens) * page_size``
        free slots, i.e. one extra page of headroom per request.

        Args:
            prefix_lens, seq_lens, last_loc, extend_num_tokens: passed through
                unchanged to ``token_to_kv_pool_allocator.alloc_extend``.

        Returns:
            The slot indices returned by ``alloc_extend``.

        Raises:
            RuntimeError: if ``alloc_extend`` returns None.
        """
        page_size = self.token_to_kv_pool_allocator.page_size
        # One page of headroom per request (NOTE(review): presumably for a
        # partially filled trailing page — confirm against the allocator).
        reserve = extend_num_tokens + len(seq_lens) * page_size
        if self.token_to_kv_pool_allocator.available_size() < reserve:
            if self.tree_cache is not None:
                self.tree_cache.evict(reserve)

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
            prefix_lens, seq_lens, last_loc, extend_num_tokens
        )
        if out_cache_loc is None:
            available_size = self.token_to_kv_pool_allocator.available_size()
            # tree_cache may be None (guarded above); report 0 evictable tokens
            # rather than raising AttributeError while building the message.
            evictable_size = (
                self.tree_cache.evictable_size() if self.tree_cache is not None else 0
            )
            # The last two lines reproduce the output of the `{expr=}` debug
            # specifier (which uses repr()).
            error_msg = (
                "Prefill out of memory. Try to lower your batch size.\n"
                f"Try to allocate {extend_num_tokens} tokens.\n"
                f"Avaliable tokens: {available_size + evictable_size}\n"
                f"self.token_to_kv_pool_allocator.available_size()={available_size!r}\n"
                f"self.tree_cache.evictable_size()={evictable_size!r}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        return out_cache_loc

    def alloc_paged_token_slots_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate KV-cache slots for one decode step with a paged pool.

        Eviction from the tree cache (when configured) is triggered if the
        allocator has fewer than ``len(seq_lens) * page_size`` free slots.

        Args:
            seq_lens, last_loc: passed through unchanged to
                ``token_to_kv_pool_allocator.alloc_decode``.

        Returns:
            The slot indices returned by ``alloc_decode``.

        Raises:
            RuntimeError: if ``alloc_decode`` returns None.
        """
        page_size = self.token_to_kv_pool_allocator.page_size
        reserve = len(seq_lens) * page_size
        if self.token_to_kv_pool_allocator.available_size() < reserve:
            if self.tree_cache is not None:
                self.tree_cache.evict(reserve)

        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)

        if out_cache_loc is None:
            available_size = self.token_to_kv_pool_allocator.available_size()
            # tree_cache may be None (guarded above); report 0 evictable tokens
            # rather than raising AttributeError while building the message.
            evictable_size = (
                self.tree_cache.evictable_size() if self.tree_cache is not None else 0
            )
            # The last two lines reproduce the output of the `{expr=}` debug
            # specifier (which uses repr()).
            error_msg = (
                "Decode out of memory. Try to lower your batch size.\n"
                f"Try to allocate {len(seq_lens)} tokens.\n"
                f"Avaliable tokens: {available_size + evictable_size}\n"
                f"self.token_to_kv_pool_allocator.available_size()={available_size!r}\n"
                f"self.tree_cache.evictable_size()={evictable_size!r}\n"
            )
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        return out_cache_loc

763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        """Separate encoder (image) tokens from decoder tokens in an extend batch.

        Sets ``encoder_lens_cpu`` / ``encoder_cached`` / ``encoder_lens`` from each
        request's ``image_inputs.num_image_tokens`` (0 when absent). Then, per
        request, subtracts the encoder length from ``seq_lens[i]`` and either
        strips the encoder part from ``input_ids[i]``, ``extend_lens`` and
        ``extend_num_tokens`` (encoder not in the prefix) or from
        ``prefix_lens`` (prefix already covers it). ``out_cache_loc`` is split
        into decoder slots (kept in ``out_cache_loc``) and encoder slots
        (``encoder_out_cache_loc``).

        ``input_ids`` (per-request token lists) and ``seq_lens`` are mutated in
        place, and the batch's ``input_ids`` / ``seq_lens`` tensors are rebuilt.
        """
        lens_cpu = []
        cached = []
        for req in self.reqs:
            image_inputs = req.image_inputs
            if image_inputs is not None and image_inputs.num_image_tokens is not None:
                num_image_tokens = image_inputs.num_image_tokens
                lens_cpu.append(num_image_tokens)
                cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= num_image_tokens
                )
            else:
                # No image input: nothing to encode for this request.
                lens_cpu.append(0)
                cached.append(True)
        self.encoder_lens_cpu = lens_cpu
        self.encoder_cached = cached

        self.encoder_lens = torch.tensor(lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Walk each request's contiguous slice of out_cache_loc and split it.
        all_locs = self.out_cache_loc
        decoder_locs = []
        encoder_locs = []
        offset = 0
        for i, (req, enc_len) in enumerate(zip(self.reqs, lens_cpu)):
            seq_lens[i] -= enc_len
            num_prefix = len(req.prefix_indices)
            start = offset
            end = offset + req.extend_input_len

            if num_prefix >= enc_len:
                # The prefix already covers the encoder tokens.
                decoder_locs.append(all_locs[start:end])
                self.prefix_lens[i] -= enc_len
            else:
                # NOTE: the encoder part should be considered as a whole
                assert num_prefix == 0
                input_ids[i] = input_ids[i][enc_len:]
                encoder_locs.append(all_locs[start : start + enc_len])
                decoder_locs.append(all_locs[start + enc_len : end])
                self.extend_lens[i] -= enc_len
                self.extend_num_tokens -= enc_len

            offset = end

        # Rebuild the device tensors from the stripped host-side data.
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        def _cat_or_empty(chunks):
            # torch.cat rejects an empty list, so fall back to an empty int64 tensor.
            if chunks:
                return torch.cat(chunks)
            return torch.zeros(0, dtype=torch.int64).to(self.device, non_blocking=True)

        self.out_cache_loc = _cat_or_empty(decoder_locs)
        self.encoder_out_cache_loc = _cat_or_empty(encoder_locs)

        assert len(self.out_cache_loc) == self.extend_num_tokens

834
    def prepare_for_extend(self):
        """Turn the current requests into an EXTEND (prefill) forward batch.

        In order, this:
          1. reserves one ``req_to_token_pool`` slot per request;
          2. builds flattened input ids and per-request length tensors on
             ``self.device``;
          3. per request: writes its cached prefix indices into its pool row,
             gathers ``input_embeds``, updates ``cached_tokens`` /
             ``already_computed``, converts ``logprob_start_len`` into an offset
             inside this extend chunk and, when logprobs are requested, collects
             the next-token ids used to score input logprobs;
          4. allocates KV-cache slots for the new tokens and records them in
             each request's ``req_to_token_pool`` row;
          5. strips encoder tokens for encoder-decoder models and builds
             ``sampling_info``.
        """
        self.forward_mode = ForwardMode.EXTEND

        reqs = self.reqs
        bs = len(reqs)
        req_pool_indices = self.alloc_req_slots(bs)

        # Host-side per-request bookkeeping.
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        seq_lens = [len(r.fill_ids) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]
        input_ids = [r.fill_ids[pre:] for r, pre in zip(reqs, prefix_lens)]
        extend_num_tokens = sum(len(ids) for ids in input_ids)

        def to_device_int64(values):
            return torch.tensor(values, dtype=torch.int64).to(
                self.device, non_blocking=True
            )

        req_pool_indices_tensor = to_device_int64(req_pool_indices)
        input_ids_tensor = to_device_int64(sum(input_ids, []))
        seq_lens_tensor = to_device_int64(seq_lens)
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

        input_embeds = []
        logprob_token_ids = []

        def collect_logprob_token_ids(req, pre_len, seq_len):
            # Input logprobs are scored against the *next* token, so take
            # origin_input_ids shifted by one, starting at the later of the
            # cached prefix and req.logprob_start_len. E.g. (chunk size 2):
            #   prompt [1, 2, 3, 4], chunk covers [1, 2] -> ids [2, 3]
            #   prompt [1, 2, 3, 4], chunk covers [3, 4] -> ids [4], padded to [4, 0]
            # The ids are padded with 0 up to
            # extend_input_len - extend_logprob_start_len entries.
            start = max(pre_len, req.logprob_start_len)
            shifted = req.origin_input_ids[start + 1 : seq_len + 1]
            logprob_token_ids.extend(shifted)
            num_pad = (
                req.extend_input_len - req.extend_logprob_start_len - len(shifted)
            )
            logprob_token_ids.extend([0] * num_pad)

        for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
            req.req_pool_idx = req_pool_indices[i]
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                # Positions [0, pre_len) of the request's row hold its cached prefix.
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            if req.input_embeds is not None:
                # Flattened across requests (extend, not append).
                input_embeds.extend(req.input_embeds)

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False

            # logprob_start_len is absolute; make it relative to this chunk.
            if req.logprob_start_len < pre_len:
                req.extend_logprob_start_len = 0
            else:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )

            if self.return_logprob:
                collect_logprob_token_ids(req, pre_len, seq_len)

        extend_input_logprob_token_ids = (
            torch.tensor(logprob_token_ids) if self.return_logprob else None
        )

        # Allocate KV-cache slots for the new (non-prefix) tokens.
        if self.token_to_kv_pool_allocator.page_size != 1:
            last_loc = get_last_loc(
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
            )
            out_cache_loc = self.alloc_paged_token_slots_extend(
                prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
            )
        else:
            out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        # Publish the batch fields.
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )
        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = prefix_lens
        self.extend_lens = extend_lens
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Record the new slots in req_to_token_pool. The torch_native path writes
        # positions [prefix_len, seq_len) of each row; the Triton kernel is
        # presumably equivalent (its body is defined elsewhere).
        if global_server_args_dict["attention_backend"] == "torch_native":
            offset = 0
            for i, num_new in enumerate(extend_lens):
                self.req_to_token_pool.write(
                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
                    out_cache_loc[offset : offset + num_new],
                )
                offset += num_new
        else:
            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                req_pool_indices_tensor,
                prefix_lens_tensor,
                seq_lens_tensor,
                extend_lens_tensor,
                out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self, self.model_config.vocab_size
        )
1005

1006
    def mix_with_running(self, running_batch: "ScheduleBatch"):
        """Fold a running decode batch into this extend batch (MIXED mode).

        Each running request contributes one token, expressed as an extend of
        length 1 whose ``fill_ids`` are its prompt plus everything generated so
        far.
        """
        self.forward_mode = ForwardMode.MIXED
        num_running = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        # Concatenate the pre-merge tensors of both batches, then write the
        # results back once merge_batch has run.
        merged_input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        merged_out_cache_loc = torch.cat(
            [self.out_cache_loc, running_batch.out_cache_loc]
        )
        self.merge_batch(running_batch)
        self.input_ids = merged_input_ids
        self.out_cache_loc = merged_out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but decode steps are not
        # cached individually, so the prefix length is derived from token counts.
        running_prefix_lens = []
        for r in running_batch.reqs:
            running_prefix_lens.append(
                len(r.origin_input_ids) + len(r.output_ids) + delta
            )
        self.prefix_lens.extend(running_prefix_lens)

        self.extend_lens.extend([1] * num_running)
        self.extend_num_tokens += num_running
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * num_running)
1035

1036
1037
    def check_decode_mem(self, buf_multiplier=1):
        """Return whether ``len(reqs) * buf_multiplier`` KV slots are free.

        If not, asks the tree cache to evict that many tokens once and re-checks.
        """
        needed = len(self.reqs) * buf_multiplier
        allocator = self.token_to_kv_pool_allocator
        if allocator.available_size() < needed:
            self.tree_cache.evict(needed)
            return bool(allocator.available_size() >= needed)
        return True

1048
    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory.

        Requests are popped from the end of ``sorted_indices`` and released
        until the allocator can cover ``get_required_tokens`` for the requests
        that remain. The loop always runs at least once. A single remaining
        request is never retracted (an assertion requires some free space).

        Returns:
            A tuple ``(retracted_reqs, new_estimate_ratio)``: the retracted
            requests and, for the kept ones,
            ``(decoded + retract_decode_steps * n) / sum(max_new_tokens)``
            capped at 1.0.
        """
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            # Sorted descending by (#output tokens, -#prompt tokens), so pop()
            # retracts the request with the fewest outputs first, preferring
            # the longest prompt on ties.
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                # Evict toward the same target the loop condition checks
                # (including speculative-decoding headroom). Otherwise, with spec
                # decoding, another request would be retracted even though
                # evicting more cache could satisfy the requirement. Without
                # spec decoding this equals n * retract_decode_steps as before.
                residual_size = max(
                    0,
                    get_required_tokens(len(sorted_indices))
                    - self.token_to_kv_pool_allocator.available_size(),
                )
                self.tree_cache.evict(residual_size)

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio
1140

1141
1142
1143
1144
    def prepare_encoder_info_decode(self):
        """Mark every request's encoder part as cached for a decode step."""
        self.encoder_cached = [True for _ in self.reqs]

Ke Bao's avatar
Ke Bao committed
1145
1146
    def prepare_for_idle(self):
        """Turn this batch into an IDLE forward batch with empty tensors.

        All per-token/per-request tensors become zero-length on ``self.device``,
        the token counters are zeroed, and ``sampling_info`` is rebuilt.
        """
        self.forward_mode = ForwardMode.IDLE

        def empty(dtype):
            return torch.empty(0, dtype=dtype, device=self.device)

        self.input_ids = empty(torch.int64)
        self.seq_lens = empty(torch.int64)
        self.out_cache_loc = empty(torch.int64)
        # NOTE(review): int32 here, while prepare_for_extend builds an int64
        # req_pool_indices tensor; kept as-is.
        self.req_pool_indices = empty(torch.int32)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self, self.model_config.vocab_size
        )
Ke Bao's avatar
Ke Bao committed
1157

1158
    def prepare_for_decode(self):
        """Advance the batch by one decode step.

        Feeds the previously sampled ``output_ids`` back in as ``input_ids``,
        updates penalizer state when penalties are required, bumps ``seq_lens``
        by one, allocates one KV-cache slot per request and records it in
        ``req_to_token_pool`` at each request's new position. For EAGLE
        speculative decoding it returns right after setting ``forward_mode``,
        because that decode batch is prepared elsewhere.
        """
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        orchestrator = self.sampling_info.penalizer_orchestrator
        if orchestrator.is_required:
            if self.enable_overlap:
                # With the overlap scheduler output_ids lag one step, so take the
                # latest token from each request (prompt tail if none yet).
                # TODO: this can be slow, optimize this.
                last_tokens = [
                    req.output_ids[-1] if req.output_ids else req.origin_input_ids[-1]
                    for req in self.reqs
                ]
                orchestrator.cumulate_output_tokens(
                    torch.tensor(last_tokens, dtype=torch.int64, device=self.device)
                )
            else:
                orchestrator.cumulate_output_tokens(self.output_ids.to(torch.int64))

        # Tokens sampled in the previous step become this step's inputs.
        self.input_ids, self.output_ids = self.output_ids, None

        if not self.model_config.is_encoder_decoder:
            locs = self.seq_lens.clone()
        else:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()

        if not self.enable_overlap:
            # A faster in-place version
            self.seq_lens.add_(1)
        else:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
        self.seq_lens_sum += bs

        # Allocate memory
        if self.token_to_kv_pool_allocator.page_size != 1:
            # seq_lens was already bumped, so seq_lens - 2 indexes the last
            # position already written in each request's req_to_token row.
            last_loc = self.req_to_token_pool.req_to_token[
                self.req_pool_indices, self.seq_lens - 2
            ]
            self.out_cache_loc = self.alloc_paged_token_slots_decode(
                self.seq_lens, last_loc
            )
        else:
            self.out_cache_loc = self.alloc_token_slots(bs)

        self.req_to_token_pool.write(
            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
        )

1223
1224
    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        """Shrink this batch in place to the requests at ``keep_indices``.

        Args:
            chunked_req_to_exclude: Consulted only when ``keep_indices`` is None;
                this request is dropped together with every finished request.
            keep_indices: Positions in ``self.reqs`` of the requests to keep, in
                the order they should appear afterwards. When None, every
                unfinished request except ``chunked_req_to_exclude`` is kept.

        If nothing is kept, only ``self.reqs`` is emptied (the per-request
        tensors are left as they were). If the number of kept requests equals
        the current batch size, the batch is left unchanged.
        """
        if keep_indices is None:
            keep_indices = [
                i
                for i, req in enumerate(self.reqs)
                if not req.finished() and req is not chunked_req_to_exclude
            ]

        if len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        # output_ids can be None (prepare_for_decode resets it, and merge_batch
        # guards for it too); indexing None would raise a TypeError.
        if self.output_ids is not None:
            self.output_ids = self.output_ids[keep_indices_device]

        # Logprob settings are only kept while some surviving request wants them.
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)
Lianmin Zheng's avatar
Lianmin Zheng committed
1273

1274
    def merge_batch(self, other: "ScheduleBatch"):
        """Append all requests of ``other`` to this batch, in place.

        Per-request tensors and lists of ``other`` are concatenated after this
        batch's own, and the batch-level boolean flags are OR-ed together.
        """
        # Merge the sampling info first: the penalizer orchestrator's merge reads
        # Batch.reqs while preparing each penalizer, so it must still see the
        # pre-merge request list.
        self.sampling_info.merge_batch(other.sampling_info)

        if self.model_config.is_encoder_decoder:
            # Encoder-decoder bookkeeping
            self.encoder_lens = torch.cat((self.encoder_lens, other.encoder_lens))
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.cat(
            (self.req_pool_indices, other.req_pool_indices)
        )
        self.seq_lens = torch.cat((self.seq_lens, other.seq_lens))
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.cat((self.output_ids, other.output_ids))

        # Keep the logprob lists aligned with the merged request list: a side that
        # did not ask for logprobs contributes neutral padding (0 top logprobs, no
        # token ids). Pad sizes are measured before self.reqs grows.
        self_wants, other_wants = self.return_logprob, other.return_logprob
        if self_wants and other_wants:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self_wants:
            n_other = len(other.reqs)
            self.top_logprobs_nums.extend([0] * n_other)
            self.token_ids_logprobs.extend([None] * n_other)
        elif other_wants:
            n_self = len(self.reqs)
            self.top_logprobs_nums = [0] * n_self + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * n_self + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        # Batch-level flags: set if either side had them set.
        for flag in (
            "return_logprob",
            "has_stream",
            "has_grammar",
            "return_hidden_states",
        ):
            setattr(self, flag, getattr(self, flag) | getattr(other, flag))

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

1312
    def get_model_worker_batch(self) -> ModelWorkerBatch:
        """Build the ``ModelWorkerBatch`` handed to the model worker.

        Side effects: refreshes ``self.sampling_info.grammars`` (when sampling
        info is present) and increments the module-level counter ``bid``; the
        incremented value becomes the returned batch's id.
        """
        global bid

        decode_seq_lens = None
        if self.forward_mode.is_decode_or_idle():
            # Extend-specific metadata does not apply to decode/idle batches.
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
            # A host copy of seq_lens is only made when one of these flags is on.
            wants_host_seq_lens = (
                global_server_args_dict["enable_flashinfer_mla"]
                or global_server_args_dict["enable_flashmla"]
            )
            if wants_host_seq_lens:
                decode_seq_lens = self.seq_lens.cpu()
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            # Per-request grammars are collected only when some request has one.
            self.sampling_info.grammars = (
                [req.grammar for req in self.reqs] if self.has_grammar else None
            )

        bid += 1

        if self.return_hidden_states:
            capture_hidden_mode = CaptureHiddenMode.FULL
        elif self.spec_info:
            # Speculative inputs may carry their own capture mode.
            capture_hidden_mode = getattr(
                self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
            )
        else:
            capture_hidden_mode = CaptureHiddenMode.NULL

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            decode_seq_lens=decode_seq_lens,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=capture_hidden_mode,
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

1379
    def copy(self):
        """Return a new ``ScheduleBatch`` carrying only result-processing fields.

        Only the fields used by ``process_batch_result`` are forwarded; the
        objects are passed by reference (no copy of them is made here).
        """
        kept_fields = dict(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )
        return ScheduleBatch(**kept_fields)

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )

Chayenne's avatar
Chayenne committed
1398

1399
@dataclasses.dataclass
class ModelWorkerBatch:
    """The part of a ``ScheduleBatch`` handed to the model worker.

    Instances are built by ``ScheduleBatch.get_model_worker_batch``. This is a
    dataclass, so the field order below is the positional constructor order:
    keep existing fields in place, and add new ones only after the last field
    without a default.
    """

    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For decode: host copy of seq_lens. get_model_worker_batch only fills it
    # for decode/idle batches with a FlashInfer-MLA or FlashMLA flag enabled.
    decode_seq_lens: Optional[torch.Tensor]

    # For extend (None for decode/idle batches)
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal (one entry per request)
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA (one entry per request)
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: Optional[SpeculativeAlgorithm] = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # Which hidden states the run should capture; if set, the output of the
    # batch contains the hidden states of the run.
    capture_hidden_mode: Optional[CaptureHiddenMode] = None
1460

1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478

# Scatter newly allocated KV-cache slot indices into the req_to_token table.
#
# Each program instance handles one request, identified by pid = tl.program_id(0)
# (so the launch grid is presumably (batch_size,) — TODO confirm at call sites).
# For request `pid` it copies seq_len - pre_len values from out_cache_loc,
# starting at offset cumsum_start = sum(extend_lens[0:pid]), into row
# req_pool_indices[pid] of req_to_token at columns [pre_len, seq_len).
# out_cache_loc is therefore read as every request's new slots concatenated in
# batch order.
# NOTE(review): the copy length uses seq_len - pre_len while the read offset is
# built from extend_lens; these presumably agree (extend_len == seq_len - pre_len)
# — confirm at call sites.
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,  # elements per req_to_token row
):
    # Number of elements copied per loop iteration.
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # NOTE: This can be slow for large bs
    # Each program serially sums extend_lens[0:pid], so the total work across the
    # grid grows quadratically with the batch size.
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    # Copy the new slots in BLOCK_SIZE chunks; the mask trims the last chunk.
    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
1497
1498
1499
1500
1501
1502
1503
1504
1505


@torch.compile(dynamic=True, backend=get_compiler_backend())
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
    """Return, per request, the req_to_token entry of its last prefix token.

    Element i is ``req_to_token[req_pool_indices_tensor[i], prefix_lens_tensor[i] - 1]``,
    or -1 when ``prefix_lens_tensor[i]`` is not positive.
    """
    has_prefix = prefix_lens_tensor > 0
    # For empty prefixes the column index is -1 (it wraps to the last column);
    # those gathered values are discarded by the torch.where below.
    gathered = req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1]
    no_prefix = torch.full_like(prefix_lens_tensor, -1)
    return torch.where(has_prefix, gathered, no_prefix)