# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward pass on the GPU.
  It will be transformed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
}
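# NOTE: these module-level defaults are typically overwritten at server startup
# with the actual ServerArgs values (see the scheduler initialization).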

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image-related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
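        # Example: with len(origin_input_ids_unpadded) == 10 and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET == 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and surr_offset = 5,
        # i.e. the last 5 prompt tokens are re-decoded as surrounding context.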
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs_val = None
        self.input_token_logprobs_idx = None
        self.input_top_logprobs_val = None
        self.input_top_logprobs_idx = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None

        # Logprobs (internal values)
        # These tokens were prefilled but need to be treated as decode tokens,
        # and their decode logprobs should be updated accordingly.
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether the request has reached a finish condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding-token decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break
        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
        )


bid = 0
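# A global, monotonically increasing batch id; incremented in get_model_worker_batch().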


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False
    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
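            # Worked example: pre_len = 4, extend_input_len = 6, logprob_start_len = 7
            # -> extend_logprob_start_len = min(7 - 4, 6 - 1) = 3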
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For the overlap scheduler, output_ids has a one-step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 text will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue
                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # Re-apply image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during the preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None
        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
        )

    def copy(self):
        # Only contains fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )

Chayenne's avatar
Chayenne committed
1159

1160
@dataclasses.dataclass
1161
class ModelWorkerBatch:
1162
1163
    # The batch id
    bid: int
1164
1165
1166
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
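

# Rough illustration of the kernel above (not executed anywhere; values assumed):
# for the program with pid = i, pre_len = 3, seq_len = 5, and the request's slice
# of out_cache_loc being [17, 42], it writes
#     req_to_token[req_pool_indices[i], 3:5] = [17, 42]
# i.e. the newly allocated KV-cache slots for that request's extended tokens.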