"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It is transferred from the CPU scheduler to the GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import Grammar
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
    # Qwen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]
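        # The pad values pack slices of the 64-bit image hash into the vocab range.
        # They serve as the fake token ids mentioned above, so different images
        # appear as different tokens to the prefix cache.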

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with the detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For vision inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[Grammar] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

        # For Qwen2-VL
        self.mrope_position_delta = []  # use mutable object

    # Whether the request has reached a finished condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        assert self.grammar is not None and self.tokenizer is not None

        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"Req(rid={self.rid}, input_ids={self.origin_input_ids})"


bid = 0
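# Global counter for assigning an increasing batch id (bid) to each
# ModelWorkerBatch created in get_model_worker_batch().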


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None

    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
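            # The allocation failed; evict entries from the tree cache to reclaim
            # KV-cache slots, then retry the allocation once.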
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Tried to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
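        # For each request, remove the encoder (image) tokens from the decoder's
        # input_ids and seq_lens, and split out_cache_loc into separate encoder
        # and decoder KV-cache locations.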
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
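            # `already_computed` estimates how much of the prefix was already
            # counted toward `cached_tokens` by an earlier chunk of a chunked
            # prefill, so cache hits are not double counted.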
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
            self.req_to_token_pool.write(
                (req.req_pool_idx, slice(pre_len, seq_len)),
                out_cache_loc[pt : pt + req.extend_input_len],
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len
            req.is_retracted = False

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            global_server_args_dict["disable_penalizer"],
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
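            # Treat each running (decode) request as a 1-token extend so it can be
            # batched together with the new prefill requests in one forward pass.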
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
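        # Current policy: retract the requests with the fewest decoded tokens first
        # (ties broken toward longer prompts), so the least decode progress is
        # discarded while freeing relatively more memory.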
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
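            # Jump-forward decoding: when the grammar uniquely determines the
            # next output characters, append them directly and re-tokenize
            # instead of decoding them one token at a time.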
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump(req.tokenizer)
                if jump_helper.can_jump():
                    suffix_ids = jump_helper.suffix_ids
                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_decode(self, enable_overlap: bool = False):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
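            # Each req_to_token row stores the encoder tokens first, so the write
            # position of the new decode token is offset by the encoder length.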
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.has_grammar:
            self.sampling_info.grammars = [req.grammar for req in self.reqs]
        else:
            self.sampling_info.grammars = None

        global bid
        bid += 1

        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
            mrope_positions_delta=mrope_positions_delta,
        )

    def copy(self):
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo
1076

    # For Qwen2-VL
    mrope_positions_delta: List[List[int]]

    def copy(self):
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

    def to(self, device: str):
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
        self.req_to_token_pool_records = [
            (x, y.to(device, non_blocking=True))
            for x, y in self.req_to_token_pool_records
        ]
        self.sampling_info.to(device)