schedule_batch.py 25.3 KB
Newer Older
1
2
from __future__ import annotations

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

18
"""Meta data for requests and batches"""
Lianmin Zheng's avatar
Lianmin Zheng committed
19

Ying Sheng's avatar
Ying Sheng committed
20
import logging
21
from dataclasses import dataclass
22
from typing import TYPE_CHECKING, List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
23
24

import torch
25

Liangsheng Yin's avatar
Liangsheng Yin committed
26
from sglang.global_config import global_config
27
28
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
29
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
30
from sglang.srt.mem_cache.chunk_cache import ChunkCache
31
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
32
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
Liangsheng Yin's avatar
Liangsheng Yin committed
33

34
35
36
37
if TYPE_CHECKING:
    from sglang.srt.layers.sampler import SampleOutput


# Maximum number of tokens before `read_offset` that are re-decoded as
# surrounding context on the first incremental-detokenization step
# (see Req.init_incremental_detokenize).
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access.
# NOTE(review): these are only defaults here; presumably overwritten from the
# server args at startup — confirm against the caller.
global_server_args_dict = {
    "disable_flashinfer": False,
    "disable_flashinfer_sampling": False,
    "triton_attention_reduce_in_fp32": False,
    "enable_mla": False,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    """Base class describing why a request stopped generating.

    Subclasses must implement ``__str__``.
    """

    def __init__(self, is_error: bool = False):
        # True for abnormal terminations (only FINISH_ABORT sets it here).
        self.is_error = is_error

    def __str__(self):
        raise NotImplementedError("Subclasses must implement this method")


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    """Finished because a matching stop/EOS token id was generated."""

    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def __str__(self) -> str:
        return f"FINISH_MATCHED_TOKEN: {self.matched}"


class FINISH_LENGTH(BaseFinishReason):
    """Finished because the output reached the given maximum length."""

    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def __str__(self) -> str:
        return f"FINISH_LENGTH: {self.length}"


class FINISH_MATCHED_STR(BaseFinishReason):
    """Finished because the generated text contains a stop string."""

    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def __str__(self) -> str:
        return f"FINISH_MATCHED_STR: {self.matched}"


class FINISH_ABORT(BaseFinishReason):
    """Finished because the request was aborted; flagged as an error."""

    def __init__(self):
        super().__init__(is_error=True)

    def __str__(self) -> str:
        return "FINISH_ABORT"
95

class Req:
    """Store all information of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        # Memory info
        self.req_pool_idx = None

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision input
        self.pixel_values = None
        self.image_size = None
        self.image_offset = None
        self.pad_value = None

        # Prefix info
        self.extend_input_len = 0
        self.prefix_indices = []
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # Logprobs
        self.return_logprob = False
        self.embedding = None
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []
        # These tokens were prefilled but must be treated as decode tokens,
        # so their decode logprobs still need to be updated.
        self.last_update_decode_tokens = 0

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    def finished(self) -> bool:
        """Return whether the request has reached a finish condition."""
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        """Rebuild ``fill_ids`` and, optionally, match it against a prefix cache.

        When ``tree_cache`` is given, ``prefix_indices`` and ``last_node`` are
        set from ``tree_cache.match_prefix``. ``extend_input_len`` becomes the
        number of ``fill_ids`` tokens not covered by the matched prefix.
        """
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        """Return the longest prefix of ``fill_ids`` that may be reused from cache.

        The prefix is shortened so that enough tokens remain to be computed:
        one for the logits when new tokens are requested, and, with logprobs,
        everything from ``logprob_start_len`` on (two tokens when a normalized
        prompt logprob is still needed). Also refreshes ``fill_ids``.
        """
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        max_prefix_len = input_len

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)

        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        """Return ``(ids_from_surr_offset, read_offset - surr_offset)``.

        On the first call, ``read_offset`` is set to the unpadded prompt length
        and ``surr_offset`` to up to INIT_INCREMENTAL_DETOKENIZATION_OFFSET
        tokens before it.
        """
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        """Decode newly produced text incrementally.

        Returns ``(True, new_text)`` when decoding the ids past ``read_offset``
        extends the surrounding text and does not end in a replacement
        character ("�", i.e. an incomplete multi-byte sequence). Otherwise, or
        without a tokenizer, returns ``(False, "")``.
        """
        if self.tokenizer is None:
            return False, ""

        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        """Set ``finished_reason`` if a stop condition is met.

        Checks, in order: output length reached ``max_new_tokens``, the last
        token is a stop/EOS token (unless ``ignore_eos``), and a stop string
        appears in the decoded tail or in ``decoded_text``. No-op if already
        finished.
        """
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        if not self.output_ids:
            # Nothing generated yet, so no token or string stop condition can match.
            return

        last_token_id = self.output_ids[-1]

        matched_eos = last_token_id in self.sampling_params.stop_token_ids

        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id

        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        if len(self.sampling_params.stop_strs) > 0:
            # Decode only the last `stop_str_max_len + 1` tokens (presumably
            # enough to contain the longest stop string). Without a tokenizer
            # only `decoded_text` can be checked, mirroring the EOS guard above.
            if self.tokenizer is not None:
                tail_str = self.tokenizer.decode(
                    self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
                )
            else:
                tail_str = ""

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        """Append ``jump_forward_str`` to the output and re-tokenize everything.

        Returns False, leaving ``output_ids`` and the FSM state untouched, when
        re-tokenizing prompt + decoded text + ``jump_forward_str`` no longer
        reproduces the prompt's last token (token fusion). Otherwise replaces
        ``output_ids``, advances ``regex_fsm_state`` to ``next_state``, resets
        the incremental-detokenization offsets, trims output logprobs to the
        unchanged leading tokens, and returns True.
        """
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        prompt_tokens = len(self.origin_input_ids_unpadded)

        if (
            len(all_ids) < prompt_tokens
            or all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]
        ):
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead.
        # NOTE(review): for i == 0 the slice is empty, so decode() presumably
        # returns "" and the loop exits immediately with
        # surr_offset == read_offset. Confirm whether starting at i == 1 was
        # intended.
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs: keep only the logprobs of the
            # leading output tokens that are unchanged after re-tokenization.
            # zip() stops at the shorter list, so a shorter re-tokenized
            # output cannot raise IndexError.
            k = 0
            for old_id, new_id in zip(old_output_ids, self.output_ids):
                if old_id != new_id:
                    break
                k += 1

            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


@dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool
    tree_cache: BasePrefixCache

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    extend_num_tokens: int = None

    # For mixed chunked prefill
    prefix_lens_cpu: List[int] = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Sampling state. Built by prepare_for_extend() and kept aligned with
    # `reqs` by filter_batch() and merge(). Declared so it is None, rather
    # than missing, before prepare_for_extend() runs.
    sampling_info: SamplingBatchInfo = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        """Create a batch for ``reqs``; logprobs are on if any request wants them."""
        return_logprob = any(req.return_logprob for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
        )

    def batch_size(self):
        """Return the number of requests (0 when ``reqs`` is None)."""
        return len(self.reqs) if self.reqs is not None else 0

    def is_empty(self):
        """Return True when the batch holds no requests (including ``reqs is None``)."""
        return self.batch_size() == 0

    def has_stream(self) -> bool:
        """Return whether the batch has at least one streaming request."""
        return any(r.stream for r in self.reqs)

    def alloc_req_slots(self, num_reqs):
        """Allocate ``num_reqs`` rows in ``req_to_token_pool``.

        Raises:
            RuntimeError: if the pool cannot provide the rows.
        """
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        """Allocate ``num_tokens`` KV-cache slots from ``token_to_kv_pool``.

        On failure, evicts ``num_tokens`` from ``tree_cache`` (if any) and
        retries once. If the retry also fails, logs an error, prints the tree
        cache, and exits the process with status 1.
        """
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. Try to lower your batch size.")
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                # Same effect as exit(1) / sys.exit(1), without relying on the
                # `site` module having injected `exit` into builtins.
                raise SystemExit(1)

        return out_cache_loc

    def prepare_for_extend(self, vocab_size: int):
        """Allocate memory for a prefill (extend) step and build its input tensors.

        For each request, the tokens of ``fill_ids`` after its cached prefix
        are the new inputs. The request's row in ``req_to_token_pool`` is
        filled with the prefix indices followed by freshly allocated KV-cache
        slots. Batched tensors are created on the CUDA device.
        """
        bs = self.batch_size()
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices_cpu = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices_cpu[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            ext_len = seq_len - pre_len
            seq_lens.append(seq_len)

            if pre_len > 0:
                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    :pre_len
                ] = req.prefix_indices

            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
                out_cache_loc[pt : pt + ext_len]
            )
            pt += ext_len

        # Set fields
        with torch.device("cuda"):
            # Flatten with a comprehension; sum(input_ids, []) is quadratic in
            # the total number of tokens.
            self.input_ids = torch.tensor(
                [token for ids in input_ids for token in ids], dtype=torch.int32
            )
            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)

        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        """Merge a running decode batch into this extend batch (mixed chunked prefill).

        Each running request contributes one extend token. Its prefix length
        is taken as all of its tokens except the last.
        """
        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
        prefix_lens_cpu.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
        self.merge(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens_cpu = prefix_lens_cpu

    def check_decode_mem(self):
        """Return whether there are free KV slots for one decode step of every request.

        Evicts from ``tree_cache`` (when present) before giving up.
        """
        bs = self.batch_size()
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        # Guarded like alloc_token_slots: there may be no prefix cache to evict from.
        if self.tree_cache is not None:
            self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        return self.token_to_kv_pool.available_size() >= bs

    def retract_decode(self):
        """Retract requests until the rest have room for a few decode steps.

        Requests are removed one by one (fewest output tokens first, ties
        broken by the longer prompt) until the pool has
        ``global_config.retract_decode_steps`` free slots per remaining
        request. At least one request is always kept. Retracted requests have
        their memory freed and their prefix/logprob state reset.

        Returns:
            ``(retracted_reqs, new_estimate_ratio)``, where the ratio is
            capped at 1.0.
        """
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs. A huge logprob_start_len means it no
            # longer caps the reusable prefix in Req.adjust_max_prefix_ids.
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
        if total_max_new_tokens == 0:
            # Avoid ZeroDivisionError when every remaining request has
            # max_new_tokens == 0; the ratio is capped at 1.0 anyway.
            return retracted_reqs, 1.0

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, model_runner):
        """Apply constrained-decoding jump-forward where possible.

        Requests that jump forward are re-tokenized, have their old tokens
        inserted into ``tree_cache``, get image padding re-applied if needed,
        and are removed from this batch.

        Returns:
            The list of requests that jumped forward.
        """
        jump_forward_reqs = []
        filter_indices = list(range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    # Byte-fallback token names, e.g. 0x80 -> "<0x80>".
                    suffix_tokens = [f"<0x{b:02X}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert. Copy output_ids: the
                    # extend() below mutates the list in place, so keeping a
                    # bare reference would make the reverts below no-ops.
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = list(req.output_ids)

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.pixel_values is not None:
                        (
                            req.origin_input_ids,
                            req.image_offset,
                        ) = model_runner.model.pad_input_ids(
                            req.origin_input_ids_unpadded,
                            req.pad_value,
                            req.pixel_values.shape,
                            req.image_size,
                        )

                    jump_forward_reqs.append(req)
                    filter_indices.remove(i)

        self.filter_batch(filter_indices)

        return jump_forward_reqs

    def prepare_for_decode(self, input_ids=None):
        """Prepare one decode step: one new token and one KV slot per request.

        Args:
            input_ids: Token ids to feed. Defaults to each request's last output
                id (or last prompt id when nothing was generated yet). When
                given explicitly, they are also passed to the penalizer
                orchestrator.
        """
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                for r in self.reqs
            ]
        else:
            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)

        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)

        # Alloc mem
        bs = self.batch_size()
        self.out_cache_loc = self.alloc_token_slots(bs)

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

        self.sampling_info.update_regex_vocab_mask(self)

    def filter_batch(self, unfinished_indices: List[int]):
        """Keep only the requests at ``unfinished_indices``, in that order.

        An empty or None index list empties ``reqs``. A list as long as
        ``reqs`` is treated as "keep everything" and returns early.
        """
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(unfinished_indices) == len(self.reqs):
            # No need to filter
            return

        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = None
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        self.sampling_info.filter(unfinished_indices, new_indices)

    def merge(self, other: "ScheduleBatch"):
        """Append ``other``'s requests and per-request state to this batch."""
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

    def check_sample_results(self, sample_output: SampleOutput):
        """Return the sampled token ids, repairing failed samples.

        If any entry of ``sample_output.success`` is False, NaNs in
        ``sample_output.probs`` are zeroed and failed rows take the argmax
        token. Both ``probs`` and ``batch_next_token_ids`` are written back
        onto ``sample_output``.
        """
        if not torch.all(sample_output.success):
            probs = sample_output.probs
            batch_next_token_ids = sample_output.batch_next_token_ids
            # Use the module logger, as the rest of this file does.
            logger.warning("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(
                sample_output.success, batch_next_token_ids, argmax_ids
            )
            sample_output.probs = probs
            sample_output.batch_next_token_ids = batch_next_token_ids

        return sample_output.batch_next_token_ids