from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "device": ServerArgs.device,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None

    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with the detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        # Updated if chunked.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return value)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None

        # Logprobs (internal values)
        # The tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
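        # Returns (read_ids, read_offset): the ids from the surrounding offset
        # onward and the relative offset at which newly decoded text starts.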
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
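        # Append the jump-forward string to the decoded text, re-tokenize the full
        # text, and rebuild output_ids and the detokenization offsets accordingly.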
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the overhead of decoding surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
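        # Allocate KV cache slots; if the pool is exhausted, evict entries from
        # the tree cache to free space and retry once before giving up.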
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
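        # Compute the encoder length of each request and split the allocated
        # KV cache locations into encoder and decoder parts.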
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
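        # Merge a running decode batch into this prefill batch so both can run
        # in one mixed forward pass; each running request becomes an extend of
        # length 1.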
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
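        # Check whether there are enough free KV cache slots for one decode
        # step, evicting from the tree cache first if necessary.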
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
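        # The last sampled output ids become the next input ids, and one new
        # KV cache slot is allocated per request.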
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            return

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
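        # Keep only the requests at keep_indices (by default, all unfinished
        # requests that are not being chunked) and filter the per-request
        # tensors and flags accordingly.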
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
                if self.spec_info
                else CaptureHiddenMode.NULL
            ),
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
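    # Each program instance handles one request: it copies the newly allocated
    # KV cache locations (out_cache_loc) for that request's extended tokens into
    # its row of req_to_token, starting at position pre_len.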
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )