from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
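
A minimal sketch of the intended flow (illustrative only; the exact call sites
live in scheduler.py, tp_worker.py, and model_runner.py)::

    batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool,
                                   tree_cache, model_config, enable_overlap,
                                   spec_algorithm, enable_custom_logit_processor)
    batch.prepare_for_extend()                           # CPU-side scheduling state
    model_worker_batch = batch.get_model_worker_batch()  # GPU-forward subset
    # ForwardBatch is then built from model_worker_batch inside ModelRunner.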
"""

import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "device": ServerArgs.device,
    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
}
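# Example access (illustrative): code in this file reads flags such as
# `global_server_args_dict["attention_backend"]`; the defaults above are assumed
# to be overwritten with the live ServerArgs values when the scheduler starts.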

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None
    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
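        # Illustrative example (assumed values): a hash of 2**31 + 5 becomes the
        # pad value 5 after `% (1 << 30)`; the modulo only bounds it below 2**30,
        # so clamping to the real vocab size is still required before a forward pass.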
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
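        # Illustrative example (assumed values): with 10 unpadded prompt tokens,
        # the first detokenization pass sets read_offset = 10 and
        # surr_offset = 10 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, so a small
        # surrounding window is always re-decoded to keep multi-token characters intact.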
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        # Updated if chunked.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None
        self.hidden_states = []

        # Logprobs (internal values)
        # The number of tokens that were prefilled but must be treated as decode
        # tokens when updating the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that have already been cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
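            # The trailing "�" (U+FFFD) check above guards against emitting text
            # that ends in an incomplete multi-byte UTF-8 sequence.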
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the overhead of decoding the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
        )


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int32
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int32
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int32
    output_ids: torch.Tensor = None  # shape: [b], int32

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        return_hidden_states: bool = False,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=return_hidden_states,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            if req.return_logprob:
                # Compute the relative logprob_start_len in an extend batch
                if req.logprob_start_len >= pre_len:
                    extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len, req.extend_input_len - 1
                    )
                else:
                    raise RuntimeError(
                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
                    )
                req.extend_logprob_start_len = extend_logprob_start_len

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            return

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)
        if self.spec_info:
            self.spec_info.filter_batch(new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
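    # Each program instance handles one request: it copies that request's newly
    # allocated KV-cache slot indices from `out_cache_loc` into its row of
    # `req_to_token`, starting at position `pre_len`.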
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )