"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on the GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image-related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use the image hash as fake token ids, which are then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]
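        # The hash is split into 16-bit chunks, each folded into the vocab range, so an
        # image maps to a stable set of in-vocab "pad" token ids that prefix matching
        # can treat like ordinary tokens.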

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
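        # Example: with 10 unpadded prompt tokens and the default offset of 5, the
        # first call to init_incremental_detokenize() sets read_offset = 10 and
        # surr_offset = 5; the detokenizer decodes ids[5:] and strips the text of
        # the surrounding ids[5:10] to recover only the newly generated text.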
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return values)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
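            # Only decode the tail that could possibly contain a stop string; the
            # extra token leaves some slack for strings that straddle token boundaries.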
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # Update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For the fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


bid = 0
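# Global counter used to assign a unique, monotonically increasing id to each
# ModelWorkerBatch produced by get_model_worker_batch().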


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # Device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Tried to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        for i, req in enumerate(reqs):
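            # Bookkeeping for cached tokens (best-effort): subtract the part of the
            # prefix that a previous chunk of this request already counted, so only
            # newly matched prefix tokens are added to `req.cached_tokens` below.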
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        write_req_to_token_pool_triton[(bs,)](
            self.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            pre_lens,
            self.seq_lens,
            extend_lens,
            self.out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
        )
        # The triton kernel is equivalent to the following python code.
        # self.req_to_token_pool.write(
        #    (req.req_pool_idx, slice(pre_len, seq_len)),
        #    out_cache_loc[pt : pt + req.extend_input_len],
        # )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=enable_overlap_schedule,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

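        # Re-estimate progress of the remaining requests: decoded tokens (plus
        # `retract_decode_steps` upcoming steps per request) over their
        # max_new_tokens budget, capped at 1.0.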
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 output will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-apply image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
        )

    def copy(self):
        # Only contains fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence lengths
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
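    # Each program instance handles one request: it copies that request's newly
    # allocated KV cache locations from `out_cache_loc` into
    # `req_to_token[req_pool_index, pre_len:seq_len]`, in chunks of BLOCK_SIZE.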
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )