from __future__ import annotations

import logging
from http import HTTPStatus
from typing import TYPE_CHECKING

import torch

from sglang.srt.disaggregation.utils import prepare_abort
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend batch by populating metadata.
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

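            # Copy this request's per-token KV cache slot indices from the
            # req-to-token pool into the flat out_cache_loc tensor.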
            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
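            # Sequence length covers the full original input plus previously
            # generated tokens, excluding the last buffered output id (it is
            # assigned to the batch later in process_prebuilt_extend).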
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

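            # Update cache accounting: add the newly reused prefix tokens and
            # record the whole sequence as computed for this request.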
            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)

        if self.return_logprob:
            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]

        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
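            # The last buffered output id (produced during the prefill stage)
            # becomes the first token this decode batch feeds into the model.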
            self.output_ids.append(req.output_ids[-1])
            self.tree_cache.cache_unfinished_req(req)
            if req.grammar is not None:
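                # Replay the prefill-generated token through the grammar so the
                # constrained-decoding state stays in sync on the decode side.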
                # FIXME: this try-except block handles an unexpected xgrammar issue.
                try:
                    req.grammar.accept_token(req.output_ids[-1])
                except ValueError as e:
                    # Grammar accept_token can raise ValueError if the token is not in the grammar.
                    # This can happen if the grammar is not set correctly or the token is invalid.
                    error_message = f"Grammar accept_token failed for req {req.rid} with token {req.output_ids[-1]}: {e}"
                    self.tree_cache.cache_finished_req(req)
                    prepare_abort(
                        req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                    )
                req.grammar.finished = req.finished()
        self.output_ids = torch.tensor(self.output_ids, device=self.device)

        # Simulate the eagle run. For now we use mock data for ease of
        # implementation, which means the first token will have an acceptance
        # rate of 0.
        if not self.spec_algorithm.is_none():

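            # Build placeholder top-k draft probabilities and indices: a
            # descending ramp normalized into (0, 1] plus sequential token
            # indices, just enough to satisfy EagleDraftInput's expected shapes.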
            b = len(self.reqs)
            topk_p = torch.arange(
                b * server_args.speculative_eagle_topk,
                0,
                -1,
                device=self.device,
                dtype=torch.float32,
            )
            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
            topk_p /= b * server_args.speculative_eagle_topk
            topk_index = torch.arange(
                b * server_args.speculative_eagle_topk, device=self.device
            )
            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

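            # Per-request hidden states carried over from the prefill stage are
            # stacked into a single [batch_size, hidden_size] tensor.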
            hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
            hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

            # local import to avoid circular import
            from sglang.srt.speculative.eagle_utils import EagleDraftInput

            spec_info = EagleDraftInput(
                topk_p=topk_p,
                topk_index=topk_index,
                hidden_states=hidden_states,
                verified_id=self.output_ids,
            )
            spec_info.prepare_for_extend(self)
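            # Only the last position's hidden state needs to be captured to
            # seed the next draft step.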
            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
            self.spec_info = spec_info