from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm


INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image-related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
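        # Illustrative note (not extra logic): the modulo folds any hash into
        # [0, 2**30), e.g. a hash of 2**31 + 5 maps to a pad value of 5. The
        # result can still exceed vocab_size, which is why the note above asks
        # for clamping before the ids reach the model forward.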

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
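        # Worked example (illustrative values only): with
        # origin_input_ids_unpadded of length 7, the first call to
        # init_incremental_detokenize() sets read_offset = 7 and
        # surr_offset = max(7 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0) = 2,
        # so tokens [2:7] form the surrounding context used to detect
        # incomplete UTF-8 while tokens [7:] are the newly decoded ones.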
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return value)
        self.input_token_logprobs_val = None
        self.input_token_logprobs_idx = None
        self.input_top_logprobs_val = None
        self.input_top_logprobs_idx = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
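            # Worked example (illustrative numbers): with pre_len = 10,
            # extend_input_len = 6 and logprob_start_len = 12, the relative
            # start is min(12 - 10, 6 - 1) = 2; with logprob_start_len = 4
            # (inside the cached prefix) it falls back to 6 - 1 = 5.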
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
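        # Note (descriptive only): both branches above record, for each request i,
        # which KV-cache slots hold its newly extended tokens by copying
        # out_cache_loc[pt : pt + extend_lens[i]] into
        # req_to_token[req_pool_indices[i], pre_lens[i] : seq_lens[i]].
        # The Triton kernel defined at the bottom of this file does this copy in
        # parallel across requests; the Python loop is the torch-native fallback.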
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
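        # Illustrative arithmetic (assumed numbers): with two remaining requests
        # that have decoded 10 and 20 tokens, max_new_tokens = 100 each, and
        # retract_decode_steps = 20, the estimate is
        # (30 + 20 * 2) / 200 = 0.35, capped at 1.0.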

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            return

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
                if self.spec_info
                else CaptureHiddenMode.NULL
            ),
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    capture_hidden_mode: CaptureHiddenMode = None


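# What the kernel below does (descriptive comment, no behavior change): each
# Triton program handles one request `pid`. It sums extend_lens[:pid] to find
# where that request's slice of out_cache_loc starts, then copies
# out_cache_loc[cumsum_start : cumsum_start + (seq_len - pre_len)] into
# req_to_token[req_pool_index, pre_len : seq_len] in BLOCK_SIZE chunks.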
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )