"vscode:/vscode.git/clone" did not exist on "d75096587df1a1d32f303fbbd6db6fe92bf06e1b"
schedule_batch.py 38.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
  It will be transformed from CPU scheduler to GPU model runner.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import dataclasses
import logging
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import Grammar
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
}
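# NOTE: the values above are only the ServerArgs defaults; the scheduler is expected
# to overwrite this dict with the live server arguments at startup (code below reads
# keys such as "disable_penalizer" that are not part of the defaults listed here).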


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }


@dataclasses.dataclass
class ImageInputs:
    """The image related inputs."""

    pixel_values: torch.Tensor
    image_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None
    num_image_tokens: Optional[int] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
    # Qwen2-VL related
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hashes=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hashes
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

        optional_args = [
            "image_sizes",
            "modalities",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
        ]
        for arg in optional_args:
            if arg in obj:
                setattr(ret, arg, obj[arg])

        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # version id to sync decode status with in detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # For multimodal inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_being_chunked = 0

        # For retraction
        self.is_retracted = False

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens were prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding (return values)
        self.embedding = None

        # Constrained decoding
        self.grammar: Optional[Grammar] = None

        # The number of tokens that were already cached in the KV cache
        self.cached_tokens = 0

    # Whether the request has reached a finished condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
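    # Illustrative example of the offsets above: with a 5-token prompt, the first
    # call sets read_offset = 5 and surr_offset = max(5 - 5, 0) = 0. Later calls
    # re-decode all_ids[surr_offset:] and only emit the text that extends beyond the
    # already-returned prefix, so partially decoded characters ("�") are not streamed.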

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = False

        # Check stop token ids
        if self.sampling_params.stop_token_ids:
            matched_eos = last_token_id in self.sampling_params.stop_token_ids
        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id
            if self.tokenizer.additional_stop_token_ids:
                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        # Check stop strings
        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
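        # Append the grammar-driven jump_forward_str to the decoded text, re-encode the
        # full text, and reconcile output_ids, the detokenization offsets, and (if
        # enabled) the logprob bookkeeping with the re-tokenized ids. Returns False when
        # the retokenization is inconsistent (e.g. token fusion at the prompt boundary).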
        assert self.grammar is not None and self.tokenizer is not None

        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the surrounding tokens decoding overhead
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        # update the inner state of the grammar
        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

        if self.return_logprob:
            # For fast-forward part's logprobs
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


bid = 0


@dataclasses.dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
    token_to_kv_pool: BaseTokenToKVPool = None
    tree_cache: BasePrefixCache = None

    # For utility
    model_config: ModelConfig = None

    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    # The output locations of the KV cache
    out_cache_loc: torch.Tensor = None
    output_ids: torch.Tensor = None

    # The sum of all sequence lengths
    seq_lens_sum: int = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    decoding_reqs: List[Req] = None

    # For encoder-decoder
    encoder_cached: Optional[List[bool]] = None
    encoder_lens: Optional[torch.Tensor] = None
    encoder_lens_cpu: Optional[List[int]] = None
    encoder_out_cache_loc: Optional[torch.Tensor] = None

    # Stream
    has_stream: bool = False

    # Has grammar
    has_grammar: bool = False

    # device
    device: str = "cuda"

    @classmethod
    def init_new(
        cls,
        reqs: List[Req],
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
    ):
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            model_config=model_config,
            return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
            has_grammar=any(req.grammar for req in reqs),
            device=req_to_token_pool.device,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
                    f"Available tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []

        for req in self.reqs:
            im = req.image_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
                self.encoder_cached.append(True)
            else:
                self.encoder_lens_cpu.append(im.num_image_tokens)
                self.encoder_cached.append(
                    self.forward_mode.is_decode()
                    or len(req.prefix_indices) >= im.num_image_tokens
                )

        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        # Strip encoder infos
        pt = 0
        decoder_out_cache_loc = []
        encoder_out_cache_loc = []
        for i, req in enumerate(self.reqs):
            encoder_len = self.encoder_lens_cpu[i]
            seq_lens[i] -= encoder_len

            if len(req.prefix_indices) < encoder_len:
                # NOTE: the encoder part should be considered as a whole
                assert len(req.prefix_indices) == 0
                input_ids[i] = input_ids[i][encoder_len:]
                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
                )
                self.extend_lens[i] -= encoder_len
                self.extend_num_tokens -= encoder_len
            else:
                decoder_out_cache_loc.append(
                    self.out_cache_loc[pt : pt + req.extend_input_len]
                )
                self.prefix_lens[i] -= encoder_len

            pt += req.extend_input_len

        # Reassign
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        if not decoder_out_cache_loc:
            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.out_cache_loc = torch.cat(decoder_out_cache_loc)

        if not encoder_out_cache_loc:
            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
                self.device, non_blocking=True
            )
        else:
            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)

        assert len(self.out_cache_loc) == self.extend_num_tokens

    def prepare_for_extend(self):
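        # Build a prefill/extend batch: allocate request slots and KV cache slots,
        # write the prefix and new token locations into req_to_token_pool, and
        # prepare the batched tensors and sampling info used by the model runner.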
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.write(
                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                )
            self.req_to_token_pool.write(
                (req.req_pool_idx, slice(pre_len, seq_len)),
                out_cache_loc[pt : pt + req.extend_input_len],
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len
            req.is_retracted = False

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        self.out_cache_loc = out_cache_loc

        self.seq_lens_sum = sum(seq_lens)
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        if self.model_config.is_encoder_decoder:
            self.prepare_encoder_info_extend(input_ids, seq_lens)

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
            global_server_args_dict["disable_penalizer"],
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
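        # Mix the new prefill (extend) batch with the running decode batch so both are
        # executed in one forward pass (mixed chunked prefill). Each running request
        # contributes exactly one new token, hence extend_input_len = 1 below.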
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens += running_bs

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0
            req.is_retracted = True

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

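        # Estimate how much of the max_new_tokens budget is already consumed, assuming
        # each remaining request decodes ~retract_decode_steps more tokens. The caller
        # can use this ratio to admit new requests more conservatively after retraction.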
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.grammar is not None:
                jump_helper = req.grammar.try_jump(req.tokenizer)
                if jump_helper.can_jump():
                    suffix_ids = jump_helper.suffix_ids
                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.grammar.jump_forward_str_state(jump_helper)

                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # re-applying image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_encoder_info_decode(self):
        # Reset the encoder cached status
        self.encoder_cached = [True] * len(self.reqs)

    def prepare_for_decode(self, enable_overlap: bool = False):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

        if self.model_config.is_encoder_decoder:
            locs = self.encoder_lens + self.seq_lens
            self.prepare_encoder_info_decode()
        else:
            locs = self.seq_lens

        if enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens = self.seq_lens + 1
        else:
            # A faster in-place version
            self.req_to_token_pool.write(
                (self.req_pool_indices, locs), self.out_cache_loc
            )
            self.seq_lens.add_(1)
        self.seq_lens_sum += bs

    def filter_batch(
        self,
        being_chunked_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        if self.model_config.is_encoder_decoder:
            self.encoder_lens = self.encoder_lens[keep_indices]
            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_grammar = any(req.grammar for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        # Encoder-decoder infos
        if self.model_config.is_encoder_decoder:
            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_grammar = self.has_grammar or other.has_grammar

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens

        if self.has_grammar:
            self.sampling_info.grammars = [req.grammar for req in self.reqs]
        else:
            self.sampling_info.grammars = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_sum=self.seq_lens_sum,
            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=[r.image_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
            lora_paths=[req.lora_path for req in self.reqs],
            sampling_info=self.sampling_info,
        )

    def copy(self):
        # Only contains fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
            model_config=self.model_config,
            forward_mode=self.forward_mode,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclasses.dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
    seq_lens_sum: int

    # The memory pool operation records
    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
    encoder_lens: Optional[torch.Tensor]
    encoder_lens_cpu: Optional[List[int]]
    encoder_out_cache_loc: Optional[torch.Tensor]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    def copy(self):
        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())

    def to(self, device: str):
        self.input_ids = self.input_ids.to(device, non_blocking=True)
        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
        self.req_to_token_pool_records = [
            (x, y.to(device, non_blocking=True))
            for x, y in self.req_to_token_pool_records
        ]
        self.sampling_info.to(device)