"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self, message="Unknown error"):
        super().__init__(is_error=True)
        self.message = message

    def to_json(self):
        return {
            "type": "abort",
            "message": self.message,
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
    # QWen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
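        # Derive deterministic pad values from the image hash so that identical
        # images map to the same fake tokens during prefix matching.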
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # The number of prefilled tokens that need to be treated as decode tokens
        # and have their decode logprobs updated accordingly
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None

        # The number of tokens that are already cached in the KV cache
        self.cached_tokens = 0

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens' decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None

    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For DP attention
    global_num_tokens: Optional[List[int]] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
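        # Try to allocate KV cache slots; on failure, evict entries from the tree
        # cache and retry once before giving up.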
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
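        # For encoder-decoder models: record per-request encoder lengths, strip the
        # encoder part from the decoder inputs, and split out_cache_loc into encoder
        # and decoder locations.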
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.zeros(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
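        # Build the extend (prefill) batch: allocate request slots and KV cache slots,
        # write KV indices for both the cached prefix and the newly extended tokens,
        # and materialize the batched tensors consumed by the model runner.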
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
            self.req_to_token_pool.write(
                (req.req_pool_idx, slice(pre_len, seq_len)),
                out_cache_loc[pt : pt + req.extend_input_len],
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len
            req.is_retracted = False

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            global_server_args_dict["disable_penalizer"],
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()
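        # Mixed chunked prefill: fold the running decode requests into this extend
        # batch so both run in a single forward pass; each running request
        # contributes exactly one new token.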

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
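        # Pick requests to retract when the KV cache cannot hold another decode step:
        # their KV cache slots are freed and their prefix state is reset so they can
        # be prefilled again later.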
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )
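        # After the descending sort, pop() retracts the request with the fewest decoded
        # tokens first (ties broken toward longer prompts), so the least progress is lost.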

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
                if jump_helper:
                    suffix_ids, _ = jump_helper

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_idle(self):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
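        # Build the decode batch: the tokens sampled in the previous step become the
        # new input_ids, one KV cache slot is allocated per request, and each sequence
        # length is advanced by one.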
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
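        # Extract the GPU-relevant subset of this batch into a ModelWorkerBatch,
        # tagged with a process-wide monotonically increasing batch id.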
        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.sampling_info is not None:
            if self.has_grammar:
                self.sampling_info.grammars = [req.grammar for req in self.reqs]
            else:
                self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For DP attention
    global_num_tokens: Optional[List[int]]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    def copy(self):
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

    def to(self, device: str):
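        # Move the tensor fields onto the target device ahead of the forward pass.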
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
        self.req_to_token_pool_records = [
            (x, y.to(device, non_blocking=True))
            for x, y in self.req_to_token_pool_records
        ]
        self.sampling_info.to(device)