# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on the GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image-related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: Optional[List[Tuple[int, int, int]]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use the image hash as fake token_ids, which are then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]
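        # The pad values above are slices of the image hash (shifted by 0, 16, 32,
        # and 64 bits) reduced modulo vocab_size, so different images are very
        # unlikely to map to the same fake token sequence during prefix matching.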

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id

        self.sampling_params = sampling_params
        self.lora_path = lora_path
        self.input_embeds = input_embeds

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
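        #
        # Example (illustrative): for a request with 8 prompt tokens and no output
        # yet, init_incremental_detokenize() sets read_offset = 8 and
        # surr_offset = max(8 - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0) = 3.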
        self.vid = 0  # version id to sync decode status with the detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # The number of tokens that were prefilled but should be treated as decode
        # tokens, i.e., their logprobs still need to be updated as decode logprobs
        self.last_update_decode_tokens = 0

        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""

        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the overhead of decoding the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # Update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For the logprobs of the fast-forwarded part
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # Batch configs
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    enable_overlap: bool = False

    # Sampling info
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    input_embeds: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None
    extend_logprob_start_lens: List[int] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool: BaseTokenToKVPool,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Tried to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        input_embeds = []

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # If input_embeds are available, store them
            if req.input_embeds is not None:
                # If req.input_embeds is already a list, append its content directly
                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
            if input_embeds
            else None
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        write_req_to_token_pool_triton[(bs,)](
            self.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            pre_lens,
            self.seq_lens,
            extend_lens,
            self.out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
        )
        # The triton kernel is equivalent to the following python code.
        # self.req_to_token_pool.write(
        #    (req.req_pool_idx, slice(pre_len, seq_len)),
        #    out_cache_loc[pt : pt + req.extend_input_len],
        # )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=self.enable_overlap,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
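        # In mixed mode, the prefill requests in `self` and the currently running
        # decode requests are folded into one extend forward pass; each running
        # request contributes exactly one extend token (its next decode position).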
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc

        # For overlap scheduler, the output_ids has one step delay
        delta = 0 if self.enable_overlap else -1

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) + delta
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_num_tokens += running_bs
        # TODO (lianmin): Revisit this. It should be seq_len - 1
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
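        # One decode step consumes exactly one new KV-cache slot per running
        # request, so `bs` free token slots (possibly after eviction) are enough.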
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)
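        # Illustrative example: 2 remaining requests with 10 decoded tokens in
        # total, max_new_tokens summing to 100, and retract_decode_steps = 20
        # give (10 + 20 * 2) / 100 = 0.5.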

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 characters are not corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # Insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # Re-apply image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
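        # In decode mode, the token sampled at the previous step becomes this
        # step's input, and each request is allocated exactly one new KV slot.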

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged, because
        # orchestrator.merge() depends on Batch.reqs during the preparation of each penalizer,
        # so it needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # The input embeds
    input_embeds: Optional[torch.Tensor] = None


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )
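

# A rough pure-Python equivalent of the kernel above (illustrative only, not used
# at runtime; the actual write path is the Triton kernel):
#
#     for pid in range(batch_size):
#         req_pool_index = req_pool_indices[pid]
#         pre_len, seq_len = pre_lens[pid], seq_lens[pid]
#         start = sum(extend_lens[:pid])
#         req_to_token[req_pool_index, pre_len:seq_len] = \
#             out_cache_loc[start : start + (seq_len - pre_len)]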