"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
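    """Base class for the reason a request is finished."""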
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
    # Qwen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # The number of tokens that were already cached in the KV store
        self.cached_tokens = 0

        # For vision inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_inflight_req = 0

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding
        self.embedding = None

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

        # For Qwen2-VL
        self.mrope_position_delta = []  # use mutable object

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
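        """Compute fill_ids and the cached prefix before the next extend (prefill) pass."""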
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
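        """Incrementally detokenize the newest ids; return (success, new_text)."""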
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
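        """Set finished_reason when the request hits a length, EOS-token, or stop-string condition."""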
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = last_token_id in self.sampling_params.stop_token_ids

        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id

        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
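        """Append the jump-forward string, re-tokenize the full text, and resync output ids, offsets, and logprobs."""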
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid={self.rid}, input_ids={self.origin_input_ids}"


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None

    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has regex
    has_regex: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs,
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_regex=any(req.regex_fsm for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. Try to lower your batch size.")
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
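        """Split the allocated cache locations and lengths into encoder and decoder parts."""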
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
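        """Allocate request and token slots and build the batched tensors for an extend (prefill) pass."""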
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
            self.req_to_token_pool.write(
                (req.req_pool_idx, slice(pre_len, seq_len)),
                out_cache_loc[pt : pt + req.extend_input_len],
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            global_server_args_dict["disable_penalizer"],
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
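        """Mix this extend (prefill) batch with a running decode batch (mixed chunked prefill)."""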
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
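        """Return True if there is at least one free KV cache slot per request, evicting from the tree cache if needed."""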
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
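        """Retract some decoding requests to free KV cache memory; return them and a new token-usage estimate ratio."""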
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
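        """Try regex jump-forward decoding for each request; return the requests that were jump-forwarded and removed from the batch."""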
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 encoding will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_decode(self, enable_overlap: bool = False):
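        """Prepare a decode step: the last output token becomes the input and one new KV cache slot is allocated per request."""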
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        current_inflight_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
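        """Drop finished requests (or keep only keep_indices) and filter all per-request batch fields accordingly."""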
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not current_inflight_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_regex = any(req.regex_fsm for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
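        """Merge another ScheduleBatch into this one, concatenating all per-request fields."""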
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_regex = self.has_regex or other.has_regex

    def get_model_worker_batch(self):
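        """Extract the GPU-related subset of this batch into a ModelWorkerBatch."""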
        if self.forward_mode.is_decode():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.has_regex:
            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
            self.sampling_info.regex_fsm_states = [
                req.regex_fsm_state for req in self.reqs
            ]
        else:
            self.sampling_info.regex_fsms = None

        global bid
        bid += 1

        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            mrope_positions_delta=mrope_positions_delta,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    # For Qwen2-VL
    mrope_positions_delta: List[List[int]]

    def copy(self):
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

    def to(self, device: str):
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
        self.req_to_token_pool_records = [
            (x, y.to(device, non_blocking=True))
            for x, y in self.req_to_token_pool_records
        ]
        self.sampling_info.to(device)