from __future__ import annotations

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Metadata for requests and batches"""

import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "enable_mla": ServerArgs.enable_mla,
    "torchao_config": ServerArgs.torchao_config,
}
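# Added note (not part of the original file): the values above are only the
# ServerArgs defaults; the assumption is that the running server overwrites
# this dict with the actual launch arguments so other modules can read them
# without importing the server state.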

logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }


class Req:
    """Store all information of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with the detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision inputs
        self.pixel_values = None
        self.image_sizes = None
        self.image_offsets = None
        self.pad_value = None
        self.modalities = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None

        # Sampling parameters
        self.sampling_params = None
        self.stream = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens,
        # and their logprobs should be updated together with the decode logprobs.
        self.last_update_decode_tokens = 0

        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding
        self.embedding = None

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]
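    # Illustrative example (added note, not in the original code): with
    # input_len = 10, return_logprob = True, normalized_prompt_logprob still
    # None, and logprob_start_len = 0, the limit becomes min(9, 9, 8, 0) = 0,
    # so adjust_max_prefix_ids() returns an empty prefix and no cached tokens
    # are reused.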

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
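    # Illustrative example (added note, not in the original code): for an
    # 8-token prompt with 3 decoded tokens on the first call, read_offset is
    # set to 8 and surr_offset to max(8 - 5, 0) = 3, so the method returns
    # (all_ids[3:], 5): five surrounding prompt tokens followed by the decoded
    # tokens.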

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = last_token_id in self.sampling_params.stop_token_ids

        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id

        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return
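    # Added note (not in the original code): the tail above decodes the last
    # stop_str_max_len + 1 output tokens, so a stop string that spans several
    # recently generated tokens is still caught even before it shows up in
    # self.decoded_text.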

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids}, "


@dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool
    tree_cache: BasePrefixCache

    forward_mode: ForwardMode = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    position_ids_offsets: torch.Tensor = None
    out_cache_loc: torch.Tensor = None
    extend_num_tokens: int = None

    # For mixed chunked prefill
    prefix_lens_cpu: List[int] = None
    running_bs: int = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Stream
    has_stream: bool = False

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)
        has_stream = any(req.stream for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
            has_stream=has_stream,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. Try to lower your batch size.")
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc
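    # Added note (not in the original code): alloc_token_slots() falls back to
    # evicting tree-cache entries and retrying the allocation once; if the
    # retry also fails, it logs the error and terminates the process.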

    def prepare_for_extend(self, vocab_size: int):
        self.forward_mode = ForwardMode.EXTEND

        bs = self.batch_size()
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices_cpu = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            req.req_pool_idx = req_pool_indices_cpu[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    :pre_len
                ] = req.prefix_indices

            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
                out_cache_loc[pt : pt + req.extend_input_len]
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len

        # Set fields
        with torch.device("cuda"):
            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)

        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
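    # Illustrative example (added note, not in the original code): a request
    # with 4 cached prefix tokens and 6 newly extended tokens fills
    # req_to_token[req_pool_idx][:4] from prefix_indices and
    # req_to_token[req_pool_idx][4:10] from its slice of out_cache_loc.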

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
        extend_num_tokens = self.extend_num_tokens + running_bs

        self.merge(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens_cpu.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens_cpu.extend([1] * running_bs)
        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = self.batch_size()
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False
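    # Added note (not in the original code): decoding needs one free KV-cache
    # slot per request for the next token, hence the `available_size() >= bs`
    # check before and after the eviction attempt.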

    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio
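    # Illustrative example (added note, not in the original code): if 2
    # requests remain with 100 decoded tokens in total, a combined
    # max_new_tokens budget of 400, and global_config.retract_decode_steps
    # were 20, the new estimate ratio would be
    # min(1.0, (100 + 20 * 2) / 400) = 0.35.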

    def check_for_jump_forward(self, model_runner):
        jump_forward_reqs = []
        filter_indices = [i for i in range(len(self.reqs))]

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
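                    # Added note (not in the original code): bytes in
                    # 0x80-0xBF are UTF-8 continuation bytes, so they are
                    # peeled off the jump-forward edge above and re-appended
                    # as <0xXX> byte tokens, keeping the jump aligned to a
                    # character boundary.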
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 sequence is not corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.pixel_values is not None:
                        (
                            req.origin_input_ids,
                            req.image_offsets,
                        ) = model_runner.model.pad_input_ids(
                            req.origin_input_ids_unpadded,
                            req.pad_value,
                            req.pixel_values,
                            req.image_sizes,
                        )

                    jump_forward_reqs.append(req)
                    filter_indices.remove(i)

        self.filter_batch(filter_indices)

        return jump_forward_reqs

    def prepare_for_decode(self, input_ids=None):
        self.forward_mode = ForwardMode.DECODE

        if input_ids is None:
            input_ids = [
                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                for r in self.reqs
            ]
        else:
            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)

        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
        self.seq_lens.add_(1)

        # Alloc mem
        bs = self.batch_size()
        self.out_cache_loc = self.alloc_token_slots(bs)

        self.req_to_token_pool.req_to_token[
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc
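    # Added note (not in the original code): each decode step allocates exactly
    # one new KV-cache slot per request and records it at position seq_len - 1
    # of that request's row in req_to_token.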

    def filter_batch(self, unfinished_indices: List[int]):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(unfinished_indices) == len(self.reqs):
            # No need to filter
            return

        self.reqs = [self.reqs[i] for i in unfinished_indices]
        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
        self.seq_lens = self.seq_lens[new_indices]
        self.input_ids = None
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.position_ids_offsets = self.position_ids_offsets[new_indices]
        self.out_cache_loc = None
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        self.has_stream = any(req.stream for req in self.reqs)

        self.sampling_info.filter(unfinished_indices, new_indices)

    def merge(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)
        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.position_ids_offsets = torch.concat(
            [self.position_ids_offsets, other.position_ids_offsets]
        )
        self.out_cache_loc = None
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        self.has_stream = any(req.stream for req in self.reqs)