schedule_batch.py 35.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

16
"""Meta data for requests and batches"""
Lianmin Zheng's avatar
Lianmin Zheng committed
17

Ying Sheng's avatar
Ying Sheng committed
18
import logging
19
import warnings
20
from dataclasses import dataclass
21
from enum import IntEnum, auto
Mingyi's avatar
Mingyi committed
22
from typing import List, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
23
24
25

import numpy as np
import torch
26
from flashinfer.sampling import top_k_top_p_sampling_from_probs
Liangsheng Yin's avatar
Liangsheng Yin committed
27

Liangsheng Yin's avatar
Liangsheng Yin committed
28
from sglang.global_config import global_config
29
30
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
31
32
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache
Liangsheng Yin's avatar
Liangsheng Yin committed
33
34

# Number of prompt tokens placed before the read position as "surrounding"
# context the first time a request is incrementally detokenized.
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args in a module-level dict for easy access.
# NOTE(review): these are defaults; presumably overwritten from the real server
# args at startup — confirm with the caller.
global_server_args_dict = dict(
    disable_flashinfer=False,
    disable_flashinfer_sampling=False,
    attention_reduce_in_fp32=False,
)

logger = logging.getLogger(__name__)


47
class ForwardMode(IntEnum):
    """Which kind of forward pass a batch runs through the model."""

    # Prefill a new sequence. Deprecated: EXTEND now covers this case.
    PREFILL = 1
    # Extend a sequence whose leading part already has its KV cache computed
    # (e.g. a shared system prompt).
    EXTEND = 2
    # Decode exactly one token.
    DECODE = 3

55

56
57
58
class BaseFinishReason:
    """Why a request stopped generating; subclasses provide the ``str()`` form."""

    def __init__(self, is_error: bool = False):
        # True only for abnormal terminations (FINISH_ABORT sets it).
        self.is_error = is_error

    def __str__(self):
        raise NotImplementedError("Subclasses must implement this method")


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    """Generation stopped on a stop token id (or one of a list of ids)."""

    def __init__(self, matched: Union[int, List[int]]):
        super().__init__(is_error=False)
        self.matched = matched

    def __str__(self) -> str:
        return "FINISH_MATCHED_TOKEN: {}".format(self.matched)


class FINISH_LENGTH(BaseFinishReason):
    """Generation stopped after producing ``length`` tokens."""

    def __init__(self, length: int):
        super().__init__(is_error=False)
        self.length = length

    def __str__(self) -> str:
        return "FINISH_LENGTH: {}".format(self.length)


class FINISH_MATCHED_STR(BaseFinishReason):
    """Generation stopped because a stop string appeared in the output."""

    def __init__(self, matched: str):
        super().__init__(is_error=False)
        self.matched = matched

    def __str__(self) -> str:
        return "FINISH_MATCHED_STR: {}".format(self.matched)


class FINISH_ABORT(BaseFinishReason):
    """The request was aborted; the only reason flagged as an error."""

    def __init__(self):
        super().__init__(is_error=True)

    def __str__(self) -> str:
        return "FINISH_ABORT"
97

Lianmin Zheng's avatar
Lianmin Zheng committed
98
99

class Req:
    """Store all information of a request."""

    def __init__(self, rid, origin_input_text, origin_input_ids):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.input_ids = None  # input_ids = origin_input_ids + output_ids

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision input
        self.pixel_values = None
        self.image_size = None
        self.image_offset = 0
        self.pad_value = None

        # Prefix info
        self.extend_input_len = 0
        self.prefix_indices = []
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # Logprobs
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []
        # The tokens is prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    # whether request reached finished condition
    def finished(self) -> bool:
        """Return True once a finish reason has been recorded."""
        return self.finished_reason is not None

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        """Return ``(ids, read_offset)`` for one incremental detokenization step.

        On the first call, ``read_offset`` is set to the unpadded prompt length
        and ``surr_offset`` to at most ``INIT_INCREMENTAL_DETOKENIZATION_OFFSET``
        tokens before it. The returned ids are all unpadded prompt + output ids
        from ``surr_offset`` on; the returned offset is ``read_offset`` relative
        to that slice.
        """
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        """Decode the tokens after ``read_offset`` incrementally.

        Returns ``(True, new_text)`` when decoding the surrounding + new tokens
        yields more text than the surrounding tokens alone and that text does
        not end in U+FFFD (an incomplete byte sequence); otherwise ``(False, "")``.
        """
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        """Record ``finished_reason`` when a stop condition is met.

        No-op if the request already finished. Checked in order: the output
        reached ``max_new_tokens``; the last output token is the tokenizer's EOS
        (unless ``ignore_eos``); a stop string occurs in the decoded tail of the
        output or in ``decoded_text``.
        """
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
            return

        if not self.output_ids:
            # Nothing generated yet: there is no last token to compare with EOS
            # (indexing output_ids[-1] here used to raise IndexError).
            return

        if (
            self.output_ids[-1] == self.tokenizer.eos_token_id
            and not self.sampling_params.ignore_eos
        ):
            self.finished_reason = FINISH_MATCHED_TOKEN(
                matched=self.tokenizer.eos_token_id
            )
            return

        if len(self.sampling_params.stop_strs) > 0:
            # One extra token beyond the longest stop string, presumably so a
            # stop string split across a token boundary is still caught.
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        """Append ``jump_forward_str`` to the output and re-tokenize everything.

        Returns False, leaving output ids/text/offsets/FSM state untouched,
        when the re-tokenized prompt no longer lines up with the original
        prompt ids (token fusion across the prompt/output boundary). Otherwise
        replaces ``output_ids``, extends ``decoded_text``, resets the
        detokenization offsets, moves the FSM to ``next_state``, trims logprobs
        to the prefix that survived re-tokenization, and returns True.
        """
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        prompt_tokens = len(self.origin_input_ids_unpadded)

        # Fewer ids than the prompt, or a different last prompt token, both mean
        # the prompt tail fused with the output; the length check also prevents
        # an IndexError on all_ids[prompt_tokens - 1].
        if len(all_ids) < prompt_tokens or (
            prompt_tokens > 0
            and all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]
        ):
            # TODO(lsyin): fix token fusion
            warnings.warn(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead: keep
        # the fewest trailing tokens (at least one) whose text is not a truncated
        # byte sequence. Starting at zero tokens, as before, always matched the
        # empty slice (decode([]) == "") and left no surrounding context at all.
        max_surr = min(INIT_INCREMENTAL_DETOKENIZATION_OFFSET, self.read_offset)
        for i in range(1, max_surr + 1):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs: keep only logprobs of the longest
            # output prefix that is unchanged by re-tokenization. zip() stops at
            # the shorter list, so a shorter re-tokenized output cannot IndexError.
            k = 0
            for old_id, new_id in zip(old_output_ids, self.output_ids):
                if old_id != new_id:
                    break
                k += 1
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"
Lianmin Zheng's avatar
Lianmin Zheng committed
282
283


284
@dataclass
class Batch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool
    tree_cache: RadixCache

    # Batched arguments to model runner
    # int32 CUDA tensor of token ids: the flattened non-prefix tokens of every
    # request for an extend step, or one token per request for a decode step.
    input_ids: torch.Tensor = None
    # Row of each request in req_to_token_pool (allocated in prepare_for_extend).
    req_pool_indices: torch.Tensor = None
    # Per-request total length (prefix + new tokens); int32, +1 per decode step.
    seq_lens: torch.Tensor = None
    # Per-request cached-prefix length for an extend step; reset to None by
    # prepare_for_decode, filter_batch and merge.
    prefix_lens: torch.Tensor = None
    # Per-request position id offsets; zeros when created in prepare_for_extend.
    position_ids_offsets: torch.Tensor = None
    # token_to_kv_pool slots allocated for the tokens of the current step.
    out_cache_loc: torch.Tensor = None
    # Total number of new (non-prefix) tokens in an extend step.
    extend_num_tokens: int = None

    # For processing logprobs
    # True if any request in the batch asked for logprobs.
    return_logprob: bool = False
    # Per-request Req.top_logprobs_num.
    top_logprobs_nums: List[int] = None

    # For multimodal
    pixel_values: List[torch.Tensor] = None
    image_sizes: List[List[int]] = None
    # Per-request image offset, shifted left by that request's prefix length.
    image_offsets: List[int] = None

    # Batched sampling params
    temperatures: torch.Tensor = None  # shape (bs, 1), see prepare_for_extend
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    frequency_penalties: torch.Tensor = None
    presence_penalties: torch.Tensor = None
    # (bs, vocab_size) float32 bias, or None when no request needs one.
    logit_bias: torch.Tensor = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        """Create a batch over ``reqs``; logprobs are on if any request wants them."""
        wants_logprob = False
        for candidate in reqs:
            if candidate.return_logprob:
                wants_logprob = True
                break

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=wants_logprob,
        )

    def is_empty(self):
        """Return True when the batch holds no requests."""
        return not len(self.reqs)

335
    def has_stream(self) -> bool:
        """Return whether at least one request in the batch streams its output."""
        for req in self.reqs:
            if req.stream:
                return True
        return False

339
    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
        """Fill the batch tensors for an extend (prefill) forward pass.

        For each request, the tokens after its cached prefix are flattened into
        ``input_ids``. A ``req_to_token_pool`` row is allocated per request; the
        prefix KV indices and freshly allocated KV slots for the new tokens are
        written into it. The tree cache is evicted once if KV allocation fails.
        Per-request sampling parameters are stacked into CUDA tensors.

        Args:
            vocab_size: Width of ``logit_bias`` rows, which are only allocated
                when some request has ``sampling_params.dtype == "int"``.
            int_token_logit_bias: Bias copied into the leading entries of those rows.

        Raises:
            RuntimeError: If ``req_to_token_pool`` has no free rows for the batch.
            SystemExit: If KV slots cannot be allocated even after eviction.
        """
        device = "cuda"
        bs = len(self.reqs)
        reqs = self.reqs
        # Only the tokens after the cached prefix need to be computed.
        input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
        prefix_indices = [r.prefix_indices for r in reqs]

        # Handle prefix
        flatten_input_ids = []
        extend_lens = []
        prefix_lens = []
        seq_lens = []

        req_pool_indices = self.req_to_token_pool.alloc(bs)

        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )

        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        for i, (ids, prefix) in enumerate(zip(input_ids, prefix_indices)):
            prefix_len = len(prefix)
            flatten_input_ids.extend(ids)
            extend_lens.append(len(ids))
            prefix_lens.append(prefix_len)
            if prefix_len > 0:
                # Point the start of the request's row at its cached prefix KV.
                self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                    :prefix_len
                ] = prefix
            seq_lens.append(prefix_len + len(ids))

        position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

        # Allocate memory
        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
        out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
        if out_cache_loc is None:
            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. This should never happen.")
                self.tree_cache.pretty_print()
                # The builtin exit() is injected by `site` for interactive use
                # only (absent under `python -S`); sys.exit raises the same
                # SystemExit reliably.
                import sys

                sys.exit()

        # Write the new tokens' KV slots after each request's prefix.
        pt = 0
        for i in range(bs):
            self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
            ] = out_cache_loc[pt : pt + extend_lens[i]]
            pt += extend_lens[i]

        # Handle logit bias but only allocate when needed
        logit_bias = None
        for i in range(bs):
            if reqs[i].sampling_params.dtype == "int":
                if logit_bias is None:
                    logit_bias = torch.zeros(
                        (bs, vocab_size), dtype=torch.float32, device=device
                    )
                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
            flatten_input_ids, dtype=torch.int32, device=device
        )
        self.pixel_values = [r.pixel_values for r in reqs]
        self.image_sizes = [r.image_size for r in reqs]
        self.image_offsets = [
            r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
        ]
        self.req_pool_indices = req_pool_indices
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
        self.position_ids_offsets = position_ids_offsets
        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        self.frequency_penalties = torch.tensor(
            [r.sampling_params.frequency_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.presence_penalties = torch.tensor(
            [r.sampling_params.presence_penalty for r in reqs],
            dtype=torch.float,
            device=device,
        )
        self.logit_bias = logit_bias

447
448
    def check_decode_mem(self):
        """Return True if the KV pool has room for one more token per request.

        When it does not, ``tree_cache.evict`` is asked for that many tokens
        (handing them back through ``token_to_kv_pool.free``) and the pool is
        checked once more.
        """
        needed = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= needed:
            return True

        self.tree_cache.evict(needed, self.token_to_kv_pool.free)
        return bool(self.token_to_kv_pool.available_size() >= needed)

    def retract_decode(self):
        """Retract requests until the KV pool can sustain the remaining decode.

        Requests are removed until the pool has room for
        ``global_config.retract_decode_steps`` tokens per remaining request;
        at least one request is always kept. Each retracted request gives back
        its uncached KV slots and its ``req_to_token_pool`` row, unlocks its
        radix-cache node, and has its prefix/logprob bookkeeping reset so it
        can be re-prefilled later.

        Returns:
            ``(retracted_reqs, new_estimate_ratio)``, where the ratio is
            (decoded tokens + retract_decode_steps per request) over the
            remaining requests' total ``max_new_tokens``, capped at 1.0.
        """
        sorted_indices = list(range(len(self.reqs)))

        # TODO(lsyin): improve retraction policy for radix cache
        # Sorted descending by (#output tokens, -#prompt tokens), so pop() below
        # retracts the request with the fewest output tokens (longest prompt on
        # ties) first.
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            # TODO: apply more fine-grained retraction
            req_pool_idx = req_pool_indices_cpu[idx]
            last_uncached_pos = len(req.prefix_indices)
            token_indices = self.req_to_token_pool.req_to_token[req_pool_idx][
                last_uncached_pos : seq_lens_cpu[idx]
            ]
            self.token_to_kv_pool.free(token_indices)
            # Release the request's req_to_token row too: a retracted request is
            # re-prefilled through prepare_for_extend, which allocates a new row,
            # so keeping this one would leak it.
            # NOTE(review): assumes ReqToTokenPool.free accepts a single Python
            # int index — confirm against memory_pool.py.
            self.req_to_token_pool.free(int(req_pool_idx))

            # release the last node
            self.tree_cache.dec_lock_ref(req.last_node)

            req.prefix_indices = None
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        if total_max_new_tokens == 0:
            # Empty batch (or nothing left to generate): avoid ZeroDivisionError
            # and return the cap value.
            new_estimate_ratio = 1.0
        else:
            new_estimate_ratio = (
                total_decoded_tokens
                + global_config.retract_decode_steps * len(self.reqs)
            ) / total_max_new_tokens
            new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio
519

Liangsheng Yin's avatar
Liangsheng Yin committed
520
    def check_for_jump_forward(self, model_runner):
        """Apply constrained-decoding jump-forward to eligible requests.

        A request is eligible when its ``jump_forward_map`` yields more than one
        forced byte from the current FSM state. Leading UTF-8 continuation
        bytes (0x80-0xBF) are first appended as ``<0xNN>`` tokens; then the forced
        symbol string is appended and the request is re-tokenized. Requests
        that jump forward have their old tokens inserted into ``tree_cache``,
        their last node unlocked, image padding re-applied, and are removed from
        this batch.

        Args:
            model_runner: Used only for ``model.pad_input_ids`` on vision requests.

        Returns:
            The list of requests that jumped forward (and left this batch).
        """
        jump_forward_reqs = []
        jumped_indices = set()

        req_pool_indices_cpu = None

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is None:
                continue

            jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                req.regex_fsm_state
            )
            if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
                continue

            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
            cur_state = req.regex_fsm_state
            while (
                len(jump_forward_bytes)
                and jump_forward_bytes[0][0] in continuation_range
            ):
                # continuation bytes; each edge is (byte, next FSM state)
                byte_edge = jump_forward_bytes.pop(0)
                suffix_bytes.append(byte_edge[0])
                cur_state = byte_edge[1]

            # NOTE(review): presumably the tokenizer's byte-fallback tokens.
            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
            suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

            # Current ids, for cache and revert. Build a new output list instead
            # of extending in place: extending would also mutate cur_output_ids,
            # turning both reverts below into no-ops that leave un-generated
            # suffix tokens in the request's output.
            cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
            cur_output_ids = req.output_ids

            req.output_ids = cur_output_ids + list(suffix_ids)
            decode_res, new_text = req.get_next_inc_detokenization()
            if not decode_res:
                req.output_ids = cur_output_ids
                continue

            (
                jump_forward_str,
                next_state,
            ) = req.jump_forward_map.jump_forward_symbol(cur_state)

            # Make the incrementally decoded text part of jump_forward_str
            # so that the UTF-8 will not corrupt
            jump_forward_str = new_text + jump_forward_str
            if not req.jump_forward_and_retokenize(jump_forward_str, next_state):
                req.output_ids = cur_output_ids
                continue

            # The decode status has diverged from detokenizer_manager
            req.vid += 1

            # insert the old request into tree_cache
            if req_pool_indices_cpu is None:
                req_pool_indices_cpu = self.req_pool_indices.tolist()
            self.tree_cache.cache_req(
                token_ids=cur_all_ids,
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
            )

            # unlock the last node
            self.tree_cache.dec_lock_ref(req.last_node)

            # re-applying image padding
            if req.pixel_values is not None:
                (
                    req.origin_input_ids,
                    req.image_offset,
                ) = model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
                    req.pixel_values.shape,
                    req.image_size,
                )

            jump_forward_reqs.append(req)
            jumped_indices.add(i)

        if jumped_indices:
            self.filter_batch(
                [i for i in range(len(self.reqs)) if i not in jumped_indices]
            )

        return jump_forward_reqs
Liangsheng Yin's avatar
Liangsheng Yin committed
605

606
    def prepare_for_decode(self, input_ids=None):
        """Set up the batch for one decode step.

        Args:
            input_ids: Token to feed per request. Defaults to each request's
                last output token, or its last input token if it has no output yet.

        Raises:
            SystemExit: If one KV slot per request cannot be allocated.
        """
        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
            ]
        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)
        self.prefix_lens = None

        # Alloc mem: one new KV slot per request
        bs = len(self.reqs)
        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

        if self.out_cache_loc is None:
            logger.error("Decode out of memory. This should never happen.")
            self.tree_cache.pretty_print()
            # The builtin exit() is injected by `site` for interactive use only
            # (absent under `python -S`); sys.exit raises the same SystemExit.
            import sys

            sys.exit()

        # Record the new slot at each request's last position (seq_len - 1).
        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

    def filter_batch(self, unfinished_indices: List[int]):
        """Keep only the requests at ``unfinished_indices``, in that order.

        Per-request tensors and lists are re-indexed to match; step-specific
        fields (input_ids, prefix_lens, out_cache_loc) are cleared.
        """
        self.reqs = [self.reqs[idx] for idx in unfinished_indices]
        keep = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")

        self.seq_lens = self.seq_lens[keep]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[keep]
        self.prefix_lens = None
        self.position_ids_offsets = self.position_ids_offsets[keep]
        self.out_cache_loc = None
        self.top_logprobs_nums = [
            self.top_logprobs_nums[idx] for idx in unfinished_indices
        ]
        self.return_logprob = any(r.return_logprob for r in self.reqs)

        per_request_tensors = (
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
            "logit_bias",
        )
        for name in per_request_tensors:
            tensor = getattr(self, name, None)
            if tensor is None:  # logit_bias can be None
                continue
            setattr(self, name, tensor[keep])
Lianmin Zheng's avatar
Lianmin Zheng committed
651

652
    def merge(self, other: "Batch"):
        """Append ``other``'s requests and per-request state to this batch.

        Step-specific fields (prefix_lens, out_cache_loc) are cleared. If only
        one side has a logit_bias, a zero bias is created for the other side
        first; note this assigns ``other.logit_bias`` as a side effect.
        """
        # Row count of this batch before merging. self.reqs is extended in place
        # below, so len(self.reqs) afterwards also counts other's requests; a
        # zero logit_bias sized from it would have too many rows and misalign
        # the bias rows with the merged requests.
        self_bs = len(self.reqs)
        self.reqs.extend(other.reqs)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.prefix_lens = None
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "frequency_penalties",
            "presence_penalties",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (self_bs, vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
Lianmin Zheng's avatar
Lianmin Zheng committed
694
695
696
697
698

    def sample(self, logits: torch.Tensor):
        """Sample one next token per request from ``logits``.

        Pipeline:
          1. Divide by ``self.temperatures`` and, if set, add ``self.logit_bias``.
             Both happen in place on ``logits.contiguous()``, which is the
             caller's own tensor when it is already contiguous.
          2. For every request with a ``regex_fsm``, set each logit that the
             FSM's next instruction does not allow to ``-inf``.
          3. Softmax, then top-k/top-p sampling. FlashInfer is used by default;
             ``top_k_top_p_sampling_from_probs_torch`` is used when
             ``disable_flashinfer_sampling`` is set.
          4. Rows reported as failed fall back to argmax over the probabilities
             (NaNs zeroed), and a warning is emitted.
          5. Each constrained request's FSM state is advanced with the token
             sampled for its row.

        Returns:
            The sampled token ids, one per request.
        """
        scaled = logits.contiguous()
        scaled.div_(self.temperatures)
        if self.logit_bias is not None:
            scaled.add_(self.logit_bias)

        # (row index, request) pairs for the regex-constrained requests.
        constrained = [
            (row, req)
            for row, req in enumerate(self.reqs)
            if req.regex_fsm is not None
        ]

        if constrained:
            # A single boolean vocabulary mask, reset and reused for each row.
            allowed = torch.empty_like(scaled[0], dtype=torch.bool)
            for row, req in constrained:
                instruction = req.regex_fsm.get_next_instruction(req.regex_fsm_state)
                allowed.zero_()
                allowed[instruction.tokens] = 1
                scaled[row].masked_fill_(~allowed, float("-inf"))

        # TODO(lmzheng): apply penalty
        probs = torch.softmax(scaled, dim=-1)

        if global_server_args_dict["disable_flashinfer_sampling"]:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
                probs, self.top_ks, self.top_ps
            )
        else:
            # FlashInfer consumes pre-drawn uniform samples: 32 rounds per row.
            max_top_k_round = 32
            uniform_samples = torch.rand(
                (max_top_k_round, probs.shape[0]), device=probs.device
            )
            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                probs, uniform_samples, self.top_ks, self.top_ps
            )

        if not torch.all(success):
            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
            greedy_ids = torch.argmax(
                probs.masked_fill(torch.isnan(probs), 0.0), dim=-1
            )
            batch_next_token_ids = torch.where(
                success, batch_next_token_ids, greedy_ids
            )

        if constrained:
            sampled_cpu = batch_next_token_ids.cpu().numpy()
            for row, req in constrained:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, sampled_cpu[row]
                )

        return batch_next_token_ids


@dataclass
class InputMetadata:
    """Store all information of a forward pass."""

    forward_mode: ForwardMode
    batch_size: int
    # None in decode mode when FlashInfer is enabled (see create()).
    total_num_tokens: int
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # For extend (create() sets these to None in decode mode)
    extend_seq_lens: torch.Tensor
    extend_start_loc: torch.Tensor
    # NOTE(review): create() stores the 0-d bool tensor returned by
    # torch.all(...) here, not a Python bool.
    extend_no_prefix: bool

    # Output location of the KV cache
    out_cache_loc: torch.Tensor = None

    # Output options
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Triton attention backend (filled by create() only when FlashInfer is disabled)
    triton_max_seq_len: int = 0
    triton_max_extend_len: int = 0
    triton_start_loc: torch.Tensor = None
    triton_prefix_lens: torch.Tensor = None

    # FlashInfer attention backend
    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
    flashinfer_use_ragged: bool = False

    @classmethod
    def create(
        cls,
        model_runner,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        top_logprobs_nums=None,
        return_logprob=False,
        skip_flashinfer_init=False,
    ):
        """Build the metadata for one forward pass and prepare the attention backend.

        Args:
            model_runner: Supplies ``server_args``, the token pools and the
                FlashInfer wrappers.
            forward_mode: ``ForwardMode.DECODE`` or a non-decode (extend) mode.
            req_pool_indices: Per-request index into ``req_to_token_pool``.
            seq_lens: Per-request sequence length.
            prefix_lens: Per-request prefix length. On the extend path, positions
                and ``extend_seq_lens`` cover ``prefix_lens..seq_lens``. It is
                also passed to the backend init helpers.
            position_ids_offsets: Per-request offset added to position ids.
            out_cache_loc: Output location of the KV cache (stored as-is).
            top_logprobs_nums: Stored as-is.
            return_logprob: Stored as-is.
            skip_flashinfer_init: Do not call ``init_flashinfer_args``.

        Returns:
            A populated ``InputMetadata``.
        """
        batch_size = len(req_pool_indices)
        is_decode = forward_mode == ForwardMode.DECODE
        use_flashinfer = not model_runner.server_args.disable_flashinfer

        # int(torch.sum(...)) forces a device-to-host sync, so compute the
        # token count at most once and reuse it below.
        if is_decode and use_flashinfer:
            # This variable is not needed in this case; we do not compute it
            # to stay compatible with CUDA graph.
            total_num_tokens = None
        else:
            total_num_tokens = int(torch.sum(seq_lens))

        flashinfer_use_ragged = False
        if use_flashinfer and not skip_flashinfer_init:
            # Large non-decode batches (> 4096 tokens) use the ragged prefill path.
            if not is_decode and total_num_tokens > 4096:
                flashinfer_use_ragged = True
            init_flashinfer_args(
                forward_mode,
                model_runner,
                req_pool_indices,
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
                flashinfer_use_ragged,
            )

        if is_decode:
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            extend_seq_lens = extend_start_loc = extend_no_prefix = None
        else:
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            # Positions prefix_len..seq_len (shifted by the offset) for each
            # request, concatenated. Cast explicitly to int64 so this branch
            # yields the same dtype as the decode branch.
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                dtype=torch.int64,
                device="cuda",
            )
            extend_seq_lens = seq_lens - prefix_lens
            # Exclusive prefix sum of extend_seq_lens.
            extend_start_loc = torch.zeros_like(seq_lens)
            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
            extend_no_prefix = torch.all(prefix_lens == 0)

        ret = cls(
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            extend_seq_lens=extend_seq_lens,
            extend_start_loc=extend_start_loc,
            extend_no_prefix=extend_no_prefix,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
            flashinfer_use_ragged=flashinfer_use_ragged,
        )

        if not use_flashinfer:
            (
                ret.triton_max_seq_len,
                ret.triton_max_extend_len,
                ret.triton_start_loc,
                ret.triton_prefix_lens,
            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

        return ret


def init_flashinfer_args(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
    flashinfer_use_ragged=False,
):
    """Init auxiliary variables for FlashInfer attention backend.

    Builds per-request KV index tables from ``model_runner.req_to_token_pool``
    and calls ``end_forward()`` then ``begin_forward()`` on the wrapper(s) used
    for this step:

    * decode: ``flashinfer_decode_wrapper``, indexed over ``seq_lens`` tokens;
    * extend, not ragged: the paged prefill wrapper, indexed over ``seq_lens``;
    * extend, ragged: the ragged prefill wrapper for the ``seq_lens - prefix_lens``
      new tokens, plus the paged prefill wrapper over ``prefix_lens`` tokens.

    Args:
        forward_mode: ``ForwardMode.DECODE`` or a non-decode (extend) mode.
        model_runner: Supplies the model config, ``tp_size``,
            ``req_to_token_pool`` and the prefill wrappers.
        req_pool_indices: Per-request index into ``req_to_token_pool``.
        seq_lens: Per-request sequence length.
        prefix_lens: Per-request prefix length.
        flashinfer_decode_wrapper: Wrapper prepared in decode mode.
        flashinfer_use_ragged: Route the new tokens through the ragged wrapper.
    """
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim
    batch_size = len(req_pool_indices)

    # With ragged prefill the paged wrapper only indexes the prefix tokens; the
    # remaining tokens are handed to the ragged wrapper below.
    paged_kernel_lens = prefix_lens if flashinfer_use_ragged else seq_lens

    # CSR-style layout: kv_indices[kv_indptr[i]:kv_indptr[i + 1]] are the
    # req_to_token entries of request i.
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
    kv_indices = torch.cat(
        [
            model_runner.req_to_token_pool.req_to_token[
                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
            ]
            for i in range(batch_size)
        ],
        dim=0,
    ).contiguous()
    # Presumably page size 1 (the trailing `1` passed to begin_forward), so
    # every last page holds exactly one token.
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

    if forward_mode == ForwardMode.DECODE:
        flashinfer_decode_wrapper.end_forward()
        flashinfer_decode_wrapper.begin_forward(
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )
    else:
        # extend part
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        if flashinfer_use_ragged:
            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                qo_indptr,
                qo_indptr,
                num_qo_heads,
                num_kv_heads,
                head_dim,
            )

        # cached part
        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            1,
        )


def init_triton_args(forward_mode, seq_lens, prefix_lens):
    """Init auxiliary variables for triton attention backend.

    Args:
        forward_mode: ``ForwardMode.DECODE`` or a non-decode (extend) mode.
        seq_lens: Per-request sequence length (1-D tensor).
        prefix_lens: Per-request prefix length; only read outside decode mode.

    Returns:
        ``(max_seq_len, max_extend_len, start_loc, prefix_lens)``.
        ``start_loc`` is the exclusive prefix sum of ``seq_lens`` as int32.
        ``max_extend_len`` is ``None`` in decode mode. ``prefix_lens`` is
        returned unchanged.
    """
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    # Allocate alongside seq_lens instead of on a hard-coded "cuda" device.
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device=seq_lens.device)
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)

    if forward_mode == ForwardMode.DECODE:
        max_extend_len = None
    else:
        extend_seq_lens = seq_lens - prefix_lens
        max_extend_len = int(torch.max(extend_seq_lens))

    return max_seq_len, max_extend_len, start_loc, prefix_lens


def top_k_top_p_sampling_from_probs_torch(
    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
):
    """A top-k and top-p sampling implementation with native pytorch operations.

    Args:
        probs: ``(batch, vocab)`` probabilities.
        top_ks: Per-row k; only the k most likely tokens stay eligible.
        top_ps: Per-row p; a token stays eligible while the probability mass
            strictly before it (in descending order) does not exceed p.

    Returns:
        ``(token_ids, success)``. ``token_ids`` is int64 of shape ``(batch,)``.
        ``success`` is a bool ``(batch,)`` tensor. If ``torch.multinomial``
        raises (e.g. a row with no eligible token, or NaN probabilities),
        every row is reported as failed and ``token_ids`` is all zeros.
    """
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Top-p: drop tokens whose preceding cumulative mass already exceeds p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # Top-k: drop every sorted position >= k.
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    # multinomial does not need normalized weights; rescaling by the row max
    # keeps values in a sane range. An all-zero row becomes NaN here and makes
    # multinomial raise, which is handled below.
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError:
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    # Map positions in the sorted order back to vocabulary ids.
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success