# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
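
A rough lifecycle sketch (illustrative only; the real call sites live in
scheduler.py, tp_worker.py, and model_runner.py):

    batch = ScheduleBatch.init_new(
        reqs, req_to_token_pool, token_to_kv_pool, tree_cache, model_config, enable_overlap
    )
    batch.prepare_for_extend()                       # or prepare_for_decode()
    model_worker_batch = batch.get_model_worker_batch()
    # The model runner then builds a ForwardBatch (GPU tensors) from model_worker_batch.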
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}
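
# Note (added comment): the entries above are only the ServerArgs dataclass
# defaults; the scheduler is expected to overwrite this dict with the live
# server arguments at startup. Model/runtime code then reads it directly,
# e.g. (illustrative):
#
#   if global_server_args_dict["attention_backend"] == "triton":
#       ...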


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }
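
# Illustrative example (added comment, not from the original file): a finished
# request's reason can be serialized directly, e.g.
#   FINISH_MATCHED_TOKEN(matched=2).to_json()
#   -> {"type": "stop", "matched": 2}
# which downstream code maps onto the OpenAI-style finish_reason values.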


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )
        if not isinstance(ret.image_hashes, list):
            ret.pad_values = [
                (ret.image_hashes) % vocab_size,
                (ret.image_hashes >> 16) % vocab_size,
                (ret.image_hashes >> 32) % vocab_size,
                (ret.image_hashes >> 64) % vocab_size,
            ]
        else:
            ret.pad_values = [x % vocab_size for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other, vocab_size):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        if isinstance(self.image_hashes, list) and isinstance(other.image_hashes, list):
            self.image_hashes += other.image_hashes
            self.pad_values = [x % vocab_size for x in self.image_hashes]
        else:
            self.image_hashes = hash((self.image_hashes, other.image_hashes))
            self.pad_values = [
                (self.image_hashes) % vocab_size,
                (self.image_hashes >> 16) % vocab_size,
                (self.image_hashes >> 32) % vocab_size,
                (self.image_hashes >> 64) % vocab_size,
            ]

        optional_args = [
            "image_sizes",
            "image_offsets",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id

        self.sampling_params = sampling_params
        self.lora_path = lora_path
        self.input_embeds = input_embeds

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False
        self.to_abort = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs, vocab_size):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs, vocab_size)

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
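
    # Illustrative walk-through (added comment, not in the original file): for a
    # 10-token prompt with no output yet, init_incremental_detokenize() sets
    # read_offset = 10 and surr_offset = 10 - 5 = 5, so it returns
    # (all_ids[5:], 5); get_next_inc_detokenization() below then decodes the
    # 5-token surrounding prefix and the full window and emits only the new text.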

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"Req(rid={self.rid}, input_ids={self.origin_input_ids})"


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Tried to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        write_req_to_token_pool_triton[(bs,)](
            self.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            pre_lens,
            self.seq_lens,
            extend_lens,
            self.out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
        )
        # The triton kernel is equivalent to the following python code.
        # self.req_to_token_pool.write(
        #    (req.req_pool_idx, slice(pre_len, seq_len)),
        #    out_cache_loc[pt : pt + req.extend_input_len],
        # )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
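        # Descriptive note (added comment): this ratio estimates how much of the
        # remaining requests' max_new_tokens budget is already spoken for, i.e.
        # tokens decoded so far plus one retract_decode_steps window per request,
        # capped at 1.0; the scheduler can use it to re-plan decode memory usage.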

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that UTF-8 characters are not corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
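    # Descriptive comment (added): one Triton program instance handles one
    # request. It copies that request's newly allocated KV-cache slot indices
    # from `out_cache_loc` (starting at this request's offset, `cumsum_start`)
    # into row `req_pool_index` of `req_to_token`, at columns [pre_len, seq_len),
    # processing BLOCK_SIZE elements per iteration.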
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )