# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
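
A rough sketch of the typical call flow (illustrative only; the exact call sites live in
scheduler.py, tp_worker.py, and model_runner.py):

    batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool,
                                   tree_cache, model_config, enable_overlap)
    batch.prepare_for_extend()                     # or prepare_for_decode() on later steps
    worker_batch = batch.get_model_worker_batch()  # ModelWorkerBatch handed to the TP worker,
                                                   # which builds a ForwardBatch from it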
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }

@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))

class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id

        self.sampling_params = sampling_params
        self.lora_path = lora_path
        self.input_embeds = input_embeds

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False
        self.to_abort = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
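        # For example (illustrative numbers, not from the original code): with
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 and a 12-token prompt, the first call to
        # init_incremental_detokenize() sets read_offset = 12 and surr_offset = 7, so the
        # surrounding tokens [7:12] are re-decoded together with new output tokens to avoid
        # splitting multi-byte characters.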
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens' decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break
        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k
        return True

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "


bid = 0
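# (bid is a monotonically increasing batch id; get_model_worker_batch() increments it and
#  stores it on each ModelWorkerBatch.)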


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False
    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, the output_ids have a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9
        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)
                    # Make the incrementally decoded text part of jump_forward_str
                    # so that UTF-8 sequences are not corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue
                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)
                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)
        self.filter_batch(keep_indices=list(keep_indices))
        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )

@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeddings
    input_embeds: Optional[torch.Tensor] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
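    # Scatters each request's newly allocated KV-cache slot indices (out_cache_loc) into its
    # row of req_to_token at positions [pre_len, seq_len). One program instance handles one
    # request; the cumulative-sum loop over extend_lens locates that request's slice of
    # out_cache_loc.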
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )