"docs/vscode:/vscode.git/clone" did not exist on "1f76fc6e3f6f95e823e350330e575e573f4bb3ee"
schedule_batch.py 48 KB
Newer Older
1
2
from __future__ import annotations

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "device": ServerArgs.device,
}
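# Note: the values above are only the static ServerArgs defaults; the scheduler
# is expected to overwrite this dict with the actual server arguments at startup.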

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error", status_code=None, err_type=None):
        super().__init__(is_error=True)
        self.message = message
        self.status_code = status_code
        self.err_type = err_type

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
            "status_code": self.status_code,
            "err_type": self.err_type,
        }
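
# For example (illustrative), FINISH_MATCHED_TOKEN(matched=2).to_json() returns
# {"type": "stop", "matched": 2}, and FINISH_LENGTH(length=128).to_json() returns
# {"type": "length", "length": 128}, mirroring the OpenAI-style finish reasons.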


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.ndarray]
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    # Llava related
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None
    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=obj["image_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret

    def merge(self, other):
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.image_hashes += other.image_hashes
        self.pad_values = [x % (1 << 30) for x in self.image_hashes]

        optional_args = [
            "image_sizes",
            "image_offsets",
            "image_pad_len",
            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if getattr(self, arg, None) is not None:
                setattr(self, arg, getattr(self, arg) + getattr(other, arg))


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        return_logprob: bool = False,
        top_logprobs_num: int = 0,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
        custom_logit_processor: Optional[str] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = (
            origin_input_ids_unpadded
            if origin_input_ids_unpadded
            else origin_input_ids  # Before image padding
        )
        self.origin_input_ids = origin_input_ids
        # Each decode stage's output ids
        self.output_ids = []
        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
        self.fill_ids = None
        self.session_id = session_id
        self.input_embeds = input_embeds

        # Sampling info
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.to_abort = False
        self.stream = stream
        self.eos_token_ids = eos_token_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None
        self.decoded_text = ""

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
        # Updated if chunked.
        self.extend_input_len = 0
        self.last_node = None

        # Chunked prefill
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = return_logprob
        self.logprob_start_len = 0
        self.top_logprobs_num = top_logprobs_num

        # Logprobs (return values)
        self.input_token_logprobs_val: Optional[List[float]] = None
        self.input_token_logprobs_idx: Optional[List[int]] = None
        self.input_top_logprobs_val: Optional[List[float]] = None
        self.input_top_logprobs_idx: Optional[List[int]] = None

        if return_logprob:
            self.output_token_logprobs_val = []
            self.output_token_logprobs_idx = []
            self.output_top_logprobs_val = []
            self.output_top_logprobs_idx = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
                self.output_top_logprobs_val
            ) = self.output_top_logprobs_idx = None
        self.hidden_states = []

        # Logprobs (internal values)
        # The tokens is prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0
        self.already_computed = 0

        # The number of verification forward passes in the speculative decoding.
        # This is used to compute the average acceptance length per request.
        self.spec_verify_ct = 0
        self.lora_path = lora_path

    def extend_image_inputs(self, image_inputs):
        if self.image_inputs is None:
            self.image_inputs = image_inputs
        else:
            self.image_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            # tree cache is None if the prefix is not computed with tree cache.
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
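
    # Worked example (illustrative): for a 5-token prompt with no output yet,
    # the first call sets read_offset = 5 and surr_offset = max(5 - 5, 0) = 0,
    # so read_ids covers the whole prompt and the returned offset is 5.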

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if self.to_abort:
            self.finished_reason = FINISH_ABORT()
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        if not self.sampling_params.ignore_eos:
            matched_eos = False

            # Check stop token ids
            if self.sampling_params.stop_token_ids:
                matched_eos = last_token_id in self.sampling_params.stop_token_ids
            if self.eos_token_ids:
                matched_eos |= last_token_id in self.eos_token_ids
            if self.tokenizer is not None:
                matched_eos |= last_token_id == self.tokenizer.eos_token_id
                if self.tokenizer.additional_stop_token_ids:
                    matched_eos |= (
                        last_token_id in self.tokenizer.additional_stop_token_ids
                    )
            if matched_eos:
                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
                return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def reset_for_retract(self):
        self.prefix_indices = []
        self.last_node = None
        self.extend_input_len = 0
        self.is_retracted = True

        # For incremental logprobs
        # TODO: Fix the `logprob_start_len`
        self.last_update_decode_tokens = 0
        self.logprob_start_len = 10**9

    def __repr__(self):
        return (
            f"rid(n={self.rid}, "
            f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
        )


bid = 0
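# `bid` is a monotonically increasing batch id; get_model_worker_batch() bumps it
# once for every ModelWorkerBatch it creates.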


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None  # shape: [b], int32
    input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
    req_pool_indices: torch.Tensor = None  # shape: [b], int32
    seq_lens: torch.Tensor = None  # shape: [b], int64
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None  # shape: [b], int32
    output_ids: torch.Tensor = None  # shape: [b], int32

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None

    # Enable custom logit processor
    enable_custom_logit_processor: bool = False

    # Return hidden states
    return_hidden_states: bool = False

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
        spec_algorithm: SpeculativeAlgorithm,
        enable_custom_logit_processor: bool,
        return_hidden_states: bool = False,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
            spec_algorithm=spec_algorithm,
            enable_custom_logit_processor=enable_custom_logit_processor,
            return_hidden_states=return_hidden_states,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc
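
    # Note: alloc_token_slots() first tries the free pool, then evicts from the
    # prefix cache and retries once; if that still fails, the error above is
    # logged and the process exits.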

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            if req.return_logprob:
                # Compute the relative logprob_start_len in an extend batch
                if req.logprob_start_len >= pre_len:
                    extend_logprob_start_len = min(
                        req.logprob_start_len - pre_len, req.extend_input_len - 1
                    )
                else:
                    raise RuntimeError(
                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
                    )
                req.extend_logprob_start_len = extend_logprob_start_len

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        if global_server_args_dict["attention_backend"] != "torch_native":
            write_req_to_token_pool_triton[(bs,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
                pre_lens,
                self.seq_lens,
                extend_lens,
                self.out_cache_loc,
                self.req_to_token_pool.req_to_token.shape[1],
            )
        else:
            pt = 0
            for i in range(bs):
                self.req_to_token_pool.write(
                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
                )
                pt += self.extend_lens[i]
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio
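
    # Worked example (illustrative numbers): with 2 remaining requests that have
    # decoded 10 tokens in total, a max_new_tokens sum of 100, and
    # retract_decode_steps = 20, the new estimate ratio is
    # min(1.0, (10 + 20 * 2) / 100) = 0.5.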

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        if self.spec_algorithm.is_eagle():
            return

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)
        if self.spec_info:
            self.spec_info.filter_batch(new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob |= other.return_logprob
        self.has_stream |= other.has_stream
        self.has_grammar |= other.has_grammar

        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            spec_algorithm=self.spec_algorithm,
            spec_info=self.spec_info,
            capture_hidden_mode=(
                CaptureHiddenMode.FULL
                if self.return_hidden_states
                else (
                    getattr(
                        self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
                    )
                    if self.spec_info
                    else CaptureHiddenMode.NULL
                )
            ),
        )

    def copy(self):
        # Only contains fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None

    # Speculative decoding
    spec_algorithm: SpeculativeAlgorithm = None
    spec_info: Optional[SpecInfo] = None
    capture_hidden_mode: CaptureHiddenMode = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
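
# Each kernel instance above handles one request: it copies that request's newly
# allocated out_cache_loc slots into its row of req_to_token starting at pre_len,
# in BLOCK_SIZE chunks. This mirrors the pure-Python fallback used in
# prepare_for_extend() when the attention backend is "torch_native".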