"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is passed from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "enable_nan_detection": ServerArgs.enable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
        session_id: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.session_id = session_id

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory pool info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
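        # Example (illustrative): with a 10-token prompt and
        # INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5, the first call to
        # init_incremental_detokenize() sets read_offset = 10 and surr_offset = 5,
        # so surr_ids covers tokens [5:10] and read_ids covers tokens [5:],
        # including newly decoded ones.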
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # The tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch on the scheduler."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None
    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None
    next_batch_sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs: int):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            req.is_retracted = False
            pre_lens.append(pre_len)

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        # Write to req_to_token_pool
        pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        write_req_to_token_pool_triton[(bs,)](
            self.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            pre_lens,
            self.seq_lens,
            extend_lens,
            self.out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
        )
        # The triton kernel is equivalent to the following python code.
        # self.req_to_token_pool.write(
        #    (req.req_pool_idx, slice(pre_len, seq_len)),
        #    out_cache_loc[pt : pt + req.extend_input_len],
        # )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            enable_overlap_schedule=enable_overlap_schedule,
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        """Retract the decoding requests when there is not enough memory."""
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

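        # Estimate the fraction of the total max_new_tokens budget that will be
        # used once every remaining request decodes retract_decode_steps more tokens.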
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 decoding will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
        )

    def copy(self):
        # Only contains the fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo


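# For each request (one Triton program per request), copy its newly allocated
# KV cache locations from out_cache_loc into positions [pre_len, seq_len) of
# its row in req_to_token.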
@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    # TODO: optimize this?
    cumsum_start = 0
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )