"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Store information about requests and batches.

The following is the flow of data structures for a batch:

ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

- ScheduleBatch is managed by `scheduler.py::Scheduler`.
  It contains high-level scheduling data. Most of the data is on the CPU.
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
  It contains low-level tensor data. Most of the data consists of GPU tensors.
"""

import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
global_server_args_dict = {
    "attention_backend": ServerArgs.attention_backend,
    "sampling_backend": ServerArgs.sampling_backend,
    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_penalizer": ServerArgs.disable_penalizer,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
}


logger = logging.getLogger(__name__)


class BaseFinishReason:
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

    def to_json(self):
        raise NotImplementedError()


class FINISH_MATCHED_TOKEN(BaseFinishReason):
    def __init__(self, matched: Union[int, List[int]]):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_MATCHED_STR(BaseFinishReason):
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }


class FINISH_LENGTH(BaseFinishReason):
    def __init__(self, length: int):
        super().__init__()
        self.length = length

    def to_json(self):
        return {
            "type": "length",  # to match OpenAI API's return value
            "length": self.length,
        }


class FINISH_ABORT(BaseFinishReason):
    def __init__(self):
        super().__init__(is_error=True)

    def to_json(self):
        return {
            "type": "abort",
        }
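
# Example (illustrative): FINISH_MATCHED_TOKEN(matched=2).to_json() returns
# {"type": "stop", "matched": 2}, mirroring the OpenAI API's finish_reason payload.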


@dataclass
class ImageInputs:
    """The image-related inputs."""

    pixel_values: torch.Tensor
    image_hash: int
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    pad_values: Optional[list] = None
    modalities: Optional[list] = None

    image_embeds: Optional[List[torch.Tensor]] = None
    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

    @staticmethod
    def from_dict(obj, vocab_size):
        # Use image hash as fake token_ids, which is then used for prefix matching
        ret = ImageInputs(
            pixel_values=obj["pixel_values"],
            image_hash=hash(tuple(obj["image_hashes"])),
        )
        image_hash = ret.image_hash
        ret.pad_values = [
            (image_hash) % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]
        ret.image_sizes = obj["image_sizes"]
        # Modalities are only available when pixel_values is not None
        ret.modalities = obj["modalities"] or ["image"]
        return ret


class Req:
    """The input and output status of a request."""

    def __init__(
        self,
        rid: str,
        origin_input_text: str,
        origin_input_ids: Tuple[int],
        sampling_params: SamplingParams,
        lora_path: Optional[str] = None,
    ):
        # Input and output info
        self.rid = rid
        self.origin_input_text = origin_input_text
        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
        self.origin_input_ids = origin_input_ids
        self.output_ids = []  # Each decode stage's output ids
        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids

        self.sampling_params = sampling_params
        self.lora_path = lora_path

        # Memory info
        self.req_pool_idx = None

        # Check finish
        self.tokenizer = None
        self.finished_reason = None
        self.stream = False

        # For incremental decoding
        # ----- | --------- read_ids -------|
        # ----- |   surr_ids  |
        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
        # ----- ^ ----------- ^ ----------- ^
        # ----- 1 ----------- 2 ----------- 3
        # 1: surr_offset
        # 2: read_offset
        # 3: last token
        self.vid = 0  # Version id to sync decode status with the detokenizer_manager
        self.decoded_text = ""
        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
        self.read_offset = None

        # The number of decoded tokens for token usage report. Note that
        # this does not include the jump forward tokens.
        self.completion_tokens_wo_jump_forward = 0

        # The number of tokens that were already cached in the KV store
        self.cached_tokens = 0

        # For vision inputs
        self.image_inputs: Optional[ImageInputs] = None

        # Prefix info
        self.prefix_indices = []
        self.extend_input_len = 0
        self.last_node = None
        self.is_inflight_req = 0

        # Logprobs (arguments)
        self.return_logprob = False
        self.logprob_start_len = 0
        self.top_logprobs_num = 0

        # Logprobs (return value)
        self.normalized_prompt_logprob = None
        self.input_token_logprobs = None
        self.input_top_logprobs = None
        self.output_token_logprobs = []
        self.output_top_logprobs = []

        # Logprobs (internal values)
        # These tokens are prefilled but need to be considered as decode tokens,
        # so their logprobs should be updated during decoding
        self.last_update_decode_tokens = 0

        # The relative logprob_start_len in an extend batch
        self.extend_logprob_start_len = 0

        # Embedding
        self.embedding = None

        # Constrained decoding
        self.regex_fsm: RegexGuide = None
        self.regex_fsm_state: int = 0
        self.jump_forward_map: JumpForwardMap = None

    # Whether the request has reached a finish condition
    def finished(self) -> bool:
        return self.finished_reason is not None

    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
        self.fill_ids = self.origin_input_ids + self.output_ids
        if tree_cache is not None:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
        self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

    def adjust_max_prefix_ids(self):
        self.fill_ids = self.origin_input_ids + self.output_ids
        input_len = len(self.fill_ids)

        # FIXME: To work around some bugs in logprob computation, we need to ensure each
        # request has at least one token. Later, we can relax this requirement and use `input_len`.
        max_prefix_len = input_len - 1

        if self.sampling_params.max_new_tokens > 0:
            # Need at least one token to compute logits
            max_prefix_len = min(max_prefix_len, input_len - 1)

        if self.return_logprob:
            if self.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                max_prefix_len = min(max_prefix_len, input_len - 2)
            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

        max_prefix_len = max(max_prefix_len, 0)
        return self.fill_ids[:max_prefix_len]

    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
    def init_incremental_detokenize(self):
        first_iter = self.surr_offset is None or self.read_offset is None

        if first_iter:
            self.read_offset = len(self.origin_input_ids_unpadded)
            self.surr_offset = max(
                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
            )

        all_ids = self.origin_input_ids_unpadded + self.output_ids
        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

    def get_next_inc_detokenization(self):
        if self.tokenizer is None:
            return False, ""
        read_ids, read_offset = self.init_incremental_detokenize()
        surr_ids = read_ids[:read_offset]

        surr_text = self.tokenizer.decode(
            surr_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )
        new_text = self.tokenizer.decode(
            read_ids,
            skip_special_tokens=self.sampling_params.skip_special_tokens,
            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
        )

        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            return True, new_text[len(surr_text) :]

        return False, ""

    def check_finished(self):
        if self.finished():
            return

        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
            self.finished_reason = FINISH_LENGTH(
                length=self.sampling_params.max_new_tokens
            )
            return

        last_token_id = self.output_ids[-1]

        matched_eos = last_token_id in self.sampling_params.stop_token_ids

        if self.tokenizer is not None:
            matched_eos |= last_token_id == self.tokenizer.eos_token_id

        if matched_eos and not self.sampling_params.ignore_eos:
            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
            return

        if len(self.sampling_params.stop_strs) > 0:
            tail_str = self.tokenizer.decode(
                self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
            )

            for stop_str in self.sampling_params.stop_strs:
                if stop_str in tail_str or stop_str in self.decoded_text:
                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                    return

    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
        if self.origin_input_text is None:
            # Recovering text can only use unpadded ids
            self.origin_input_text = self.tokenizer.decode(
                self.origin_input_ids_unpadded
            )

        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
        all_ids = self.tokenizer.encode(all_text)
        if not all_ids:
            logger.warning("Encoded all_text resulted in empty all_ids")
            return False

        prompt_tokens = len(self.origin_input_ids_unpadded)
        if prompt_tokens > len(all_ids):
            logger.warning("prompt_tokens is larger than encoded all_ids")
            return False

        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
            # TODO(lsyin): fix token fusion
            logger.warning(
                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
            )
            return False

        old_output_ids = self.output_ids
        self.output_ids = all_ids[prompt_tokens:]
        self.decoded_text = self.decoded_text + jump_forward_str
        self.surr_offset = prompt_tokens
        self.read_offset = len(all_ids)

        # NOTE: A trick to reduce the decoding overhead of surrounding tokens
        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
            surr_text_ = self.tokenizer.decode(
                all_ids[self.read_offset - i : self.read_offset]
            )
            if not surr_text_.endswith("�"):
                self.surr_offset = self.read_offset - i
                break

        self.regex_fsm_state = next_state

        if self.return_logprob:
            # For the logprobs of the fast-forwarded part
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == self.output_ids[i]:
                    k = k + 1
                else:
                    break
            self.output_token_logprobs = self.output_token_logprobs[:k]
            self.output_top_logprobs = self.output_top_logprobs[:k]
            self.logprob_start_len = prompt_tokens + k
            self.last_update_decode_tokens = len(self.output_ids) - k

        return True

    def __repr__(self):
        return f"rid(n={self.rid}, input_ids={self.origin_input_ids})"


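# Global counter used to assign a monotonically increasing id to each ModelWorkerBatch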
bid = 0


@dataclass
class ScheduleBatch:
    """Store all information of a batch."""

    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: BaseTokenToKVPool
    tree_cache: BasePrefixCache

    forward_mode: ForwardMode = None
    sampling_info: SamplingBatchInfo = None

    # Batched arguments to model runner
    input_ids: torch.Tensor = None
    req_pool_indices: torch.Tensor = None
    seq_lens: torch.Tensor = None
    out_cache_loc: torch.Tensor = None

    output_ids: torch.Tensor = None

    # For processing logprobs
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # For extend and mixed chunked prefill
    prefix_lens: List[int] = None
    extend_lens: List[int] = None
    extend_num_tokens: int = None
    running_bs: int = None
    decoding_reqs: List[Req] = None

    # Stream
    has_stream: bool = False

    # Device
    device: str = "cuda"

    # Has regex
    has_regex: bool = False

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)
        has_stream = any(req.stream for req in reqs)
        has_regex = any(req.regex_fsm for req in reqs)

        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
            token_to_kv_pool=token_to_kv_pool,
            tree_cache=tree_cache,
            return_logprob=return_logprob,
            has_stream=has_stream,
            device=req_to_token_pool.device,
            has_regex=has_regex,
        )

    def batch_size(self):
        return len(self.reqs)

    def is_empty(self):
        return len(self.reqs) == 0

    def alloc_req_slots(self, num_reqs):
        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
        if req_pool_indices is None:
            raise RuntimeError(
                "Out of memory. "
                "Please set a smaller number for `--max-running-requests`."
            )
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

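        # If allocation fails, evict entries from the prefix cache and retry once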
        if out_cache_loc is None:
            if self.tree_cache is not None:
                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

            if out_cache_loc is None:
                logger.error("Prefill out of memory. Try to lower your batch size.")
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
                exit(1)

        return out_cache_loc

    def prepare_for_extend(self, vocab_size: int):
        self.forward_mode = ForwardMode.EXTEND

        bs = len(self.reqs)
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []

        # Allocate memory
        req_pool_indices = self.alloc_req_slots(bs)
        out_cache_loc = self.alloc_token_slots(extend_num_tokens)

        pt = 0
        for i, req in enumerate(reqs):
            already_computed = (
                req.extend_logprob_start_len + 1 + req.cached_tokens
                if req.extend_logprob_start_len > 0
                else 0
            )
            req.cached_tokens += len(req.prefix_indices) - already_computed

            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            seq_lens.append(seq_len)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
                    req.prefix_indices
                )

            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                out_cache_loc[pt : pt + req.extend_input_len]
            )

            # Compute the relative logprob_start_len in an extend batch
            if req.logprob_start_len >= pre_len:
                extend_logprob_start_len = min(
                    req.logprob_start_len - pre_len, req.extend_input_len - 1
                )
            else:
                extend_logprob_start_len = req.extend_input_len - 1

            req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len

        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
            self.device, non_blocking=True
        )

        self.extend_num_tokens = extend_num_tokens
        self.out_cache_loc = out_cache_loc
        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self, vocab_size, global_server_args_dict["disable_penalizer"]
        )

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        self.forward_mode = ForwardMode.MIXED
        running_bs = running_batch.batch_size()

        for req in running_batch.reqs:
            req.fill_ids = req.origin_input_ids + req.output_ids
            req.extend_input_len = 1

        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
        extend_num_tokens = self.extend_num_tokens + running_bs

        self.merge_batch(running_batch)
        self.input_ids = input_ids
        self.out_cache_loc = out_cache_loc
        self.extend_num_tokens = extend_num_tokens

        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
        self.prefix_lens.extend(
            [
                len(r.origin_input_ids) + len(r.output_ids) - 1
                for r in running_batch.reqs
            ]
        )
        self.extend_lens.extend([1] * running_bs)
        self.extend_logprob_start_lens.extend([0] * running_bs)

    def check_decode_mem(self):
        bs = len(self.reqs)
        if self.token_to_kv_pool.available_size() >= bs:
            return True

        self.tree_cache.evict(bs, self.token_to_kv_pool.free)

        if self.token_to_kv_pool.available_size() >= bs:
            return True

        return False

    def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]

        # TODO(lsyin): improve retraction policy for radix cache
        sorted_indices.sort(
            key=lambda i: (
                len(self.reqs[i].output_ids),
                -len(self.reqs[i].origin_input_ids),
            ),
            reverse=True,
        )

        retracted_reqs = []
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        first_iter = True
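        # Retract requests (freeing their KV cache) until every remaining request
        # has room to decode `retract_decode_steps` more tokens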
        while (
            self.token_to_kv_pool.available_size()
            < len(sorted_indices) * global_config.retract_decode_steps
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
                    self.token_to_kv_pool.available_size() > 0
                ), "No space left for only one request"
                break

            first_iter = False
            idx = sorted_indices.pop()
            req = self.reqs[idx]
            retracted_reqs.append(req)

            if isinstance(self.tree_cache, ChunkCache):
                # ChunkCache does not have eviction
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                    last_uncached_pos : seq_lens_cpu[idx]
                ]
                self.token_to_kv_pool.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
                    - self.token_to_kv_pool.available_size()
                )
                residual_size = max(0, residual_size)
                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = []
            req.last_node = None
            req.extend_input_len = 0

            # For incremental logprobs
            req.last_update_decode_tokens = 0
            req.logprob_start_len = 10**9

        self.filter_batch(keep_indices=sorted_indices)

        # Reqs in batch are filtered
        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)

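        # Heuristic: estimate what fraction of the remaining requests' max_new_tokens
        # budget will be used, assuming each decodes at least retract_decode_steps more tokens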
        new_estimate_ratio = (
            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
        ) / total_max_new_tokens
        new_estimate_ratio = min(1.0, new_estimate_ratio)

        return retracted_reqs, new_estimate_ratio

    def check_for_jump_forward(self, pad_input_ids_func):
        jump_forward_reqs = []
        keep_indices = set(i for i in range(len(self.reqs)))

        for i, req in enumerate(self.reqs):
            if req.jump_forward_map is not None:
                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
                    req.regex_fsm_state
                )
                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
                    suffix_bytes = []
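                    # 0x80-0xBF are UTF-8 continuation bytes: first collect the bytes
                    # that complete a partially generated multi-byte character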
                    continuation_range = range(0x80, 0xC0)
                    cur_state = req.regex_fsm_state
                    while (
                        len(jump_forward_bytes)
                        and jump_forward_bytes[0][0] in continuation_range
                    ):
                        # continuation bytes
                        byte_edge = jump_forward_bytes.pop(0)
                        suffix_bytes.append(byte_edge[0])
                        cur_state = byte_edge[1]

                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)

                    # Current ids, for cache and revert
                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                    cur_output_ids = req.output_ids

                    req.output_ids.extend(suffix_ids)
                    decode_res, new_text = req.get_next_inc_detokenization()
                    if not decode_res:
                        req.output_ids = cur_output_ids
                        continue

                    (
                        jump_forward_str,
                        next_state,
                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)

                    # Make the incrementally decoded text part of jump_forward_str
                    # so that the UTF-8 encoding is not corrupted
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
                    ):
                        req.output_ids = cur_output_ids
                        continue

                    # The decode status has diverged from detokenizer_manager
                    req.vid += 1

                    # insert the old request into tree_cache
                    self.tree_cache.cache_finished_req(req, cur_all_ids)

                    # Re-apply image padding
                    if req.image_inputs is not None:
                        req.origin_input_ids = pad_input_ids_func(
                            req.origin_input_ids_unpadded, req.image_inputs
                        )

                    jump_forward_reqs.append(req)
                    keep_indices.remove(i)

        self.filter_batch(keep_indices=list(keep_indices))

        return jump_forward_reqs

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE

        self.input_ids = self.output_ids
        self.output_ids = None
        if self.sampling_info.penalizer_orchestrator:
            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                self.input_ids
            )

        # Alloc mem
        bs = len(self.reqs)
        self.out_cache_loc = self.alloc_token_slots(bs)

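        # Each decode step appends exactly one token per request: record its KV slot
        # and grow each sequence length by one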
        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
            self.out_cache_loc
        )
        self.seq_lens.add_(1)

    def filter_batch(
        self,
        current_inflight_req: Optional[Req] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        if keep_indices is None:
            keep_indices = [
                i
                for i in range(len(self.reqs))
                if not self.reqs[i].finished()
                and self.reqs[i] is not current_inflight_req
            ]

        if keep_indices is None or len(keep_indices) == 0:
            # Filter out all requests
            self.reqs = []
            return

        if len(keep_indices) == len(self.reqs):
            # No need to filter
            return

        self.reqs = [self.reqs[i] for i in keep_indices]
        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
        self.seq_lens = self.seq_lens[new_indices]
        self.out_cache_loc = None
        self.output_ids = self.output_ids[new_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)
        if self.return_logprob:
            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
        else:
            self.top_logprobs_nums = None

        self.has_stream = any(req.stream for req in self.reqs)
        self.has_regex = any(req.regex_fsm for req in self.reqs)

        self.sampling_info.filter_batch(keep_indices, new_indices)

    def merge_batch(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is
        # because orchestrator.merge() depends on Batch.reqs during the preparation of
        # each penalizer, so it needs to be called with the pre-merged Batch.reqs.
        self.sampling_info.merge_batch(other.sampling_info)

        self.req_pool_indices = torch.concat(
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
        self.out_cache_loc = None
        if self.output_ids is not None:
            self.output_ids = torch.concat([self.output_ids, other.output_ids])
        if self.return_logprob and other.return_logprob:
            self.top_logprobs_nums.extend(other.top_logprobs_nums)
        elif self.return_logprob:
            self.top_logprobs_nums.extend([0] * len(other.reqs))
        elif other.return_logprob:
            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
        self.reqs.extend(other.reqs)

        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
        self.has_regex = self.has_regex or other.has_regex

    def get_model_worker_batch(self):
        if self.forward_mode.is_decode():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
                image_inputs
            ) = None
        else:
            extend_seq_lens = self.extend_lens
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens
            image_inputs = [r.image_inputs for r in self.reqs]

        lora_paths = [req.lora_path for req in self.reqs]
        if self.has_regex:
            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
            self.sampling_info.regex_fsm_states = [
                req.regex_fsm_state for req in self.reqs
            ]
        else:
            self.sampling_info.regex_fsms = None

        global bid
        bid += 1

        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            image_inputs=image_inputs,
            lora_paths=lora_paths,
            sampling_info=self.sampling_info,
        )

    def copy(self):
        return ScheduleBatch(
            reqs=self.reqs,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool=self.token_to_kv_pool,
            tree_cache=self.tree_cache,
            forward_mode=self.forward_mode,
            output_ids=self.output_ids,
            sampling_info=self.sampling_info,
            decoding_reqs=self.decoding_reqs,
        )

    def __str__(self):
        return (
            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
            f"#req={(len(self.reqs))})"
        )


@dataclass
class ModelWorkerBatch:
    # The batch id
    bid: int
    # The forward mode
    forward_mode: ForwardMode
    # The input ids
    input_ids: torch.Tensor
    # The indices of requests in the req_to_token_pool
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
    # The indices of output tokens in the token_to_kv_pool
    out_cache_loc: torch.Tensor

    # For logprob
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]

    # For extend
    extend_seq_lens: Optional[List[int]]
    extend_prefix_lens: Optional[List[int]]
    extend_logprob_start_lens: Optional[List[int]]

    # For multimodal
    image_inputs: Optional[List[ImageInputs]]

    # For LoRA
    lora_paths: Optional[List[str]]

    # Sampling info
    sampling_info: SamplingBatchInfo

    def copy(self):
        return ModelWorkerBatch(
            bid=self.bid,
            forward_mode=self.forward_mode,
            input_ids=self.input_ids.clone(),
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
            out_cache_loc=self.out_cache_loc,
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            extend_seq_lens=self.extend_seq_lens,
            extend_prefix_lens=self.extend_prefix_lens,
            extend_logprob_start_lens=self.extend_logprob_start_lens,
            image_inputs=self.image_inputs,
            lora_paths=self.lora_paths,
            sampling_info=self.sampling_info.copy(),
        )