from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on the GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.server_args import ServerArgs
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "device": ServerArgs.device,
    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None
    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        return_hidden_states: bool = False,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        if isinstance(sampling_params.custom_params, dict):
            sampling_params = copy.copy(sampling_params)
            sampling_params.custom_params = sampling_params.custom_params | {
                "__req__": self
            }
        self.sampling_params = sampling_params

        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states

        # Memory pool info
        self.req_pool_idx: Optional[int] = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        # If we want to abort the request in the middle of the event loop, set this to true
        # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
        self.to_abort = False
        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
        self.to_abort_message: str = "Unknown error"
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
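        # For example (illustrative): with a 10-token prompt and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and surr_offset = 5,
        # so 5 "surrounding" tokens are re-decoded together with every new output token.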
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
        self.prefix_indices = []
        # Number of tokens to run prefill.
        self.extend_input_len = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0
        self.last_node = None

        # Whether it is chunked. It is incremented whenever the request is chunked,
        # and decremented whenever a chunked request is processed.
        self.is_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        # Start index to compute logprob from.
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num
        self.token_ids_logprob = token_ids_logprob

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None
        self.input_token_ids_logprobs_val: Optional[List[float]] = None
        self.input_token_ids_logprobs_idx: Optional[List[int]] = None
        # Temporary holder to store input_token_logprobs.
        self.input_token_logprobs: Optional[List[Tuple[int]]] = None
        self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
        self.temp_input_top_logprobs_idx: Optional[List[int]] = None
        self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
        self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
            self.output_token_ids_logprobs_val = []
            self.output_token_ids_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
                self.output_token_ids_logprobs_idx
            ) = None
        self.hidden_states = []

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

    @property
    def seqlen(self):
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT(
                message=self.to_abort_message,
            )
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True
        self.input_token_logprobs = None
        self.temp_input_top_logprobs_val = None
        self.temp_input_top_logprobs_idx = None
        self.extend_logprob_start_len = 0
        self.is_chunked = 0
        self.req_pool_idx = None

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int32
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int32
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int32
    output_ids: torch.Tensor = None  # shape: [b], int32

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None
    # It is an empty list if logprob is not required.
    extend_input_logprob_token_ids: Optional[torch.Tensor] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Whether to return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=any(req.return_hidden_states for req in reqs),
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "alloc_req_slots runs out of memory. "
                "Please set a smaller number for `--max-running-requests`. "
                f"{self.req_to_token_pool.available_size()=}, "
                f"{num_reqs=}, "
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []
        extend_input_logprob_token_ids = []

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                req.extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len,
                    req.extend_input_len,
                    req.seqlen - 1,
                )
            else:
                req.extend_logprob_start_len = 0

            if self.return_logprob:
                # Find input logprob token ids.
                # First, find a global index within origin_input_ids and slide it by 1
                # to compute input logprobs. It is because you need the next token
                # to compute input logprobs. E.g., (chunk size 2)
                #
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [1, 2]
                # extend_input_logprob_token_id = [2, 3]
                #
                # Note that it can also overflow. In this case, we pad it with 0.
                # input_logprobs = [1, 2, 3, 4]
                # fill_ids = [3, 4]
                # extend_input_logprob_token_id = [4, 0]
                global_start_idx, global_end_idx = (
                    len(req.prefix_indices),
                    len(req.fill_ids),
                )
                # Apply logprob_start_len
                if global_start_idx < req.logprob_start_len:
                    global_start_idx = req.logprob_start_len

                logprob_token_ids = req.origin_input_ids[
                    global_start_idx + 1 : global_end_idx + 1
                ]
                extend_input_logprob_token_ids.extend(logprob_token_ids)

                # We will need req.extend_input_len - req.extend_logprob_start_len number of
                # tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
                extend_input_logprob_token_ids.extend(
                    [0]
                    * (
                        req.extend_input_len
                        - req.extend_logprob_start_len
                        - len(logprob_token_ids)
                    )
                )

        if self.return_logprob:
            extend_input_logprob_token_ids = torch.tensor(
                extend_input_logprob_token_ids
            )
        else:
            extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
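
        # Both branches above write the newly allocated KV cache slots into
        # req_to_token[req_pool_idx, pre_len:seq_len]; the Triton path just does it for all
        # requests in one kernel launch instead of a Python loop.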
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self, server_args: ServerArgs):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        # For spec decoding, filter_batch API can only filter
        # requests from the back, so we can only retract from the back.
        # TODO(sang): Clean up finish path and support better retract
        # policy.
        if not server_args.speculative_algorithm:
            sorted_indices.sort(
                key=lambda i: (
                    len(self.reqs[i].output_ids),
                    -len(self.reqs[i].origin_input_ids),
                ),
                reverse=True,
            )

        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
                headroom_for_spec_decode += (
                    num_reqs
                    * server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                    + num_reqs * server_args.speculative_num_draft_tokens
                )
            return (
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )
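
        # Illustrative arithmetic (assuming retract_decode_steps == 20 and no speculative
        # decoding): with 8 remaining requests we keep retracting until at least
        # 8 * 20 = 160 token slots are free, i.e. room for roughly 20 more decode steps each.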

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return

        if self.sampling_info.penalizer_orchestrator.is_required:
            if self.enable_overlap:
                # TODO: this can be slow, optimize this.
                delayed_output_ids = torch.tensor(
                    [
                        (
                            req.output_ids[-1]
                            if len(req.output_ids)
                            else req.origin_input_ids[-1]
                        )
                        for req in self.reqs
                    ],
                    dtype=torch.int64,
                    device=self.device,
                )
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    delayed_output_ids
                )
            else:
                self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    self.output_ids.to(torch.int64)
                )

        self.input_ids = self.output_ids
        self.output_ids = None

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not chunked_req_to_exclude
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices_device]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
            self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None
            self.token_ids_logprobs = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, keep_indices_device)
        if self.spec_info:
            self.spec_info.filter_batch(keep_indices_device)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during the preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
            self.token_ids_logprobs.extend(other.token_ids_logprobs)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
            self.token_ids_logprobs.extend([None] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
            self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar
        self.return_hidden_states |= other.return_hidden_states

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            token_ids_logprobs=self.token_ids_logprobs,
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
            extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    token_ids_logprobs: Optional[List[List[int]]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
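    # One Triton program per request: copy this request's slice of `out_cache_loc`
    # (starting at the running sum of the previous requests' extend lengths) into row
    # `req_pool_indices[pid]` of req_to_token, at positions [pre_len, seq_len).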
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )