from __future__ import annotations

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Metadata for requests and batches."""

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union

import torch

from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "enable_mla": ServerArgs.enable_mla,
    "torchao_config": ServerArgs.torchao_config,
}


logger = logging.getLogger(__name__)
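
# Illustrative sketch (not part of the runtime logic): a ScheduleBatch created by the
# scheduler typically flows through the methods defined below, assuming `reqs`, the
# memory pools, and the tree cache are provided by the model runner setup:
#
#   batch = ScheduleBatch.init_new(reqs, req_to_token_pool, token_to_kv_pool, tree_cache)
#   batch.prepare_for_extend(vocab_size)        # prefill / extend forward pass
#   batch.prepare_for_decode()                  # one decode step per call
#   batch.filter_batch(unfinished_indices)      # drop finished requests
#   batch.merge(other_running_batch)            # merge with another running batch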


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError("Subclasses must implement this method")


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }


class Req:
    """Store all information of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids, lora_path=None):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with the detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision input
        self.pixel_values = None
        self.image_sizes = None
        self.image_offsets = None
        self.pad_value = None
        self.modalities = None

        # Prefix info
        self.extend_input_len = 0
        self.prefix_indices = []
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # Logprobs
        self.return_logprob = False
        self.embedding = None
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []
        # These tokens are prefilled but need to be considered as decode tokens,
        # and their decode logprobs should be updated accordingly
        self.last_update_decode_tokens = 0

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
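        """Prepare the inputs for the next prefill/extend round.

        Recomputes `fill_ids` (origin_input_ids + output_ids), matches the longest
        cached prefix in the tree cache if one is given, and updates
        `extend_input_len` to the number of tokens that still need to be computed.
        """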
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
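        """Initialize (or resume) incremental detokenization for this request.

        Returns (read_ids, read_offset): the token ids from `surr_offset` onward
        and the position of `read_offset` relative to `surr_offset`.
        """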
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
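        """Try to decode the next chunk of incremental text.

        Returns (True, new_text) when the newly decoded text is longer than the
        surrounding text and does not end with an incomplete UTF-8 character;
        otherwise returns (False, "").
        """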
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
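        """Set `finished_reason` when the request hits the max-new-tokens limit,
        an EOS / stop token, or a stop string."""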
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = last_token_id in self.sampling_params.stop_token_ids

        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id

        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
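        """Apply a jump-forward string and re-tokenize the whole text.

        Returns False when the re-tokenization is inconsistent with the original
        prompt tokens (e.g. token fusion between input and output); otherwise
        updates the output ids, detokenization offsets, FSM state, and logprob
        bookkeeping, and returns True.
        """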
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


@dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool
    tree_cache: BasePrefixCache

    forward_mode: ForwardMode = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    extend_num_tokens: int = None

    # For mixed chunked prefill
    prefix_lens_cpu: List[int] = None
    running_bs: int = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
        )

    def batch_size(self):
        return len(self.reqs) if self.reqs else 0

    def is_empty(self):
        return len(self.reqs) == 0

    def has_stream(self) -> bool:
        # Return whether the batch has at least one streaming request
        return any(r.stream for r in self.reqs)

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. Try to lower your batch size.")
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_for_extend(self, vocab_size: int):
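        """Prepare the batch for a prefill/extend forward pass.

        Allocates request and token slots, writes the prefix and new token
        locations into `req_to_token_pool`, and builds the batched tensors and
        sampling info on the GPU.
        """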
        self.forward_mode = ForwardMode.EXTEND

        bs = self.batch_size()
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices_cpu = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices_cpu[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            ext_len = seq_len - pre_len
            seq_lens.append(seq_len)

            if pre_len > 0:
                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    :pre_len
                ] = req.prefix_indices

            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
                out_cache_loc[pt : pt + ext_len]
            )
            pt += ext_len

        # Set fields
        with torch.device("cuda"):
            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)

        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

    def mix_with_running(self, running_batch: "ScheduleBatch"):
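        """Mix this prefill batch with a running decode batch (mixed chunked prefill).

        Each running request contributes one decode token; the two batches are then
        merged into a single MIXED-mode batch.
        """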
        self.forward_mode = ForwardMode.MIXED
        self.running_bs = running_batch.batch_size()

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
        prefix_lens_cpu.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
        self.merge(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens_cpu = prefix_lens_cpu

    def check_decode_mem(self):
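        """Return True if there is at least one free KV cache slot per request,
        evicting from the tree cache once if needed."""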
        bs = self.batch_size()
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
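        """Retract some running requests to free KV cache memory.

        Requests are popped according to a heuristic on their output/input lengths
        and their KV cache is released (or returned to the radix cache) until the
        remaining requests have enough room for the next decode steps. Returns the
        retracted requests and a new estimated token-budget ratio.
        """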
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, model_runner):
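        """Scan requests that carry a jump-forward map and apply grammar-based
        jump-forward decoding where possible.

        Jump-forwarded requests are removed from this batch, their previous tokens
        are cached in the tree cache, and they are returned to the caller.
        """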
        jump_forward_reqs = []
        filter_indices = [i for i in range(len(self.reqs))]

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 sequence will not be corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # Re-apply image padding
                    if req.pixel_values is not None:
                        (
                            req.origin_input_ids,
                            req.image_offsets,
                        ) = model_runner.model.pad_input_ids(
                            req.origin_input_ids_unpadded,
                            req.pad_value,
                            req.pixel_values,
                            req.image_sizes,
                        )

                    jump_forward_reqs.append(req)
                    filter_indices.remove(i)

        self.filter_batch(filter_indices)

        return jump_forward_reqs

    def prepare_for_decode(self, input_ids=None):
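        """Prepare the batch for a single decode step.

        Uses the last generated token of each request as input (unless `input_ids`
        is given), increments the sequence lengths, and allocates one KV cache slot
        per request.
        """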
        self.forward_mode = ForwardMode.DECODE

        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                for r in self.reqs
            ]
        else:
            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)

        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)

        # Alloc mem
        bs = self.batch_size()
        self.out_cache_loc = self.alloc_token_slots(bs)

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
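        """Keep only the requests at `unfinished_indices` and filter the batched
        tensors and sampling info accordingly."""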
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(unfinished_indices) == len(self.reqs):
            # No need to filter
            return

        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = None
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        self.sampling_info.filter(unfinished_indices, new_indices)

    def merge(self, other: "ScheduleBatch"):
        # The penalizer orchestrator must be merged before Batch.reqs is merged. This is
        # because orchestrator.merge() depends on Batch.reqs during the preparation of
        # each penalizer, so it needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)