# tree_style_proposer.py
from typing import List, Optional, Set, Tuple, Any, Dict

import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (ExecuteModelRequest,
                           SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.util import sampler_output_to_torch


class TreeStyleProposer(SpeculativeProposer):
    """Helper class which separates out sequences which would exceed the max
    model length when speculated upon.

    This allows combinations of models such as JackFram/llama-68m draft with
    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
    2048 while Llama2-13b has max_position_embeddings of 4096.

    We treat the sequences which exceed the proposal draft model length as
    "non-spec sequences". Essentially they skip the draft model and go through
    normal decoding in the target model.

    Currently, only proposal_lens of 0 and k are supported, where k is a global
    batch proposal length. In the future vLLM should support per-sequence
    proposal lengths.
    """

    def __init__(
        self,
        worker: ProposerWorkerBase,
        device: str,
        vocab_size: int,
        tree_buffers: Dict[str, Any],
        max_proposal_len: Optional[int] = None,
    ):
        """Store the draft worker and the tree buffers used for proposals.

        Args:
            worker: Draft worker whose ``sampler_output`` produces the
                speculative tokens.
            device: Device on which placeholder tensors are allocated.
            vocab_size: Vocabulary size, used for the last dimension of the
                placeholder probability tensor.
            tree_buffers: Tree layout buffers. This class reads the keys
                "tree_indices", "retrieve_indices", "tree_attn_masks" and
                "tree_position_ids".
            max_proposal_len: Optional length limit. When set, a sequence is
                only speculated on if ``seq_len + proposal_len`` is strictly
                below it.
        """
        # Draft-model side of the proposer.
        self._worker, self._device, self._vocab_size = (worker, device,
                                                        vocab_size)
        # Tree layout and the optional length limit are public attributes.
        self.tree_buffers, self.max_proposal_len = (tree_buffers,
                                                    max_proposal_len)

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Get tree-style speculative proposals given the input batch.

        Sequences which would exceed the max model length are skipped during
        speculation.

        Args:
            execute_model_req: The batch to speculate on. Its
                ``seq_group_metadata_list``, ``previous_hidden_states`` and
                ``previous_logits`` are read.
            seq_ids_with_bonus_token_in_last_step: Forwarded unchanged to the
                draft worker's ``sampler_output``.

        Returns:
            A ``SpeculativeProposals`` covering every sequence of the batch,
            including the ones that were not speculated on. ``no_proposals``
            is True when there is no sampler output to merge.
        """
        # The proposal length comes from the tree layout (first dimension of
        # the "tree_indices" buffer), not from
        # execute_model_req.num_lookahead_slots.
        proposal_len = self.tree_buffers["tree_indices"].shape[0]
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Split speculative- and non-speculative- sequences.
        (
            proposal_lens,
            nonzero_proposal_len_seqs,
            nonzero_proposal_len_indices,
        ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)

        if nonzero_proposal_len_seqs:
            # Restrict the carried-over hidden states and logits to the
            # sequences that will be speculated on. prune() is called for its
            # effect on the object; its return value is not used.
            hidden_states = execute_model_req.previous_hidden_states
            if hidden_states is not None:
                hidden_states.prune(nonzero_proposal_len_seqs)
            logits = execute_model_req.previous_logits
            if logits is not None:
                logits.prune(nonzero_proposal_len_seqs)

            nonzero_execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=nonzero_proposal_len_seqs,
                num_lookahead_slots=proposal_len,
                previous_hidden_states=hidden_states,
                previous_logits=logits,
            )

            # Speculate tokens using the draft worker for the speculative
            # sequences.
            # If sampler_transposed is true, then maybe_sampler_output's
            # token_ids is like [batch] format in proposal_len size list,
            # while if it is false, the format would be [proposal_len]
            # in batch size list
            maybe_sampler_output, transposed = self._worker.sampler_output(
                execute_model_req=nonzero_execute_model_req,
                sample_len=proposal_len,
                seq_ids_with_bonus_token_in_last_step=(
                    seq_ids_with_bonus_token_in_last_step),
            )
            (
                proposal_lens,
                maybe_sampler_output,
                nonzero_proposal_len_indices,
            ) = self._remove_no_proposal_seqs(proposal_lens,
                                              maybe_sampler_output,
                                              nonzero_proposal_len_indices,
                                              transposed)
        else:
            # If no sequences can be speculated, set sampler output to None.
            maybe_sampler_output = None
            transposed = False

        # Combine speculative- and non-speculative sequences into the same
        # representation.
        (
            proposal_tokens,
            proposal_probs,
            proposal_lens,
            cart_candidates,
            tree_attn_masks,
        ) = self._merge_outputs(
            batch_size=len(seq_group_metadata_list),
            proposal_len=proposal_len,
            maybe_sampler_output=maybe_sampler_output,
            proposal_lens=proposal_lens,
            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
            sampler_transposed=transposed,
        )

        # Shift the tree's relative position ids by each sequence's length so
        # they become per-sequence positions.
        tree_position_ids_list = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_data = next(iter(seq_group_metadata.seq_data.values()))
            if seq_data.get_first_step_flag():
                # On a sequence's first step the offset is one less than its
                # length.
                seq_len = seq_data.get_len() - 1
            else:
                seq_len = seq_data.get_len()
            tree_position_ids = (self.tree_buffers['tree_position_ids'] +
                                 seq_len)
            tree_position_ids_list.append(tree_position_ids)
        # Stack per-sequence ids and flatten them into a single column.
        tree_position_ids = torch.stack(tree_position_ids_list,
                                        dim=0).reshape(-1, 1)

        proposals = SpeculativeProposals(
            proposal_token_ids=proposal_tokens,
            proposal_probs=proposal_probs,
            proposal_lens=proposal_lens,
            no_proposals=maybe_sampler_output is None,
            cart_candidates=cart_candidates,
            retrieve_indices=self.tree_buffers['retrieve_indices'],
            tree_attn_masks=tree_attn_masks,
            tree_position_ids=tree_position_ids)

        return proposals

    def _split_by_proposal_len(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_len: int,
    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
        """Partition the batch into speculated and non-speculated sequences.

        A sequence gets a proposal length of zero when speculation is
        disabled for it (``num_speculative_tokens == 0``) or when
        ``max_proposal_len`` is set and ``seq_len + proposal_len`` is not
        strictly below it. Every other sequence gets ``proposal_len``.

        For sequences whose speculation is not already disabled, the chosen
        length is written back to ``num_speculative_tokens`` on the metadata
        object.

        Returns:
            A tuple of (per-sequence proposal lengths, metadata of the
            sequences to speculate on, batch indices of those sequences).
        """
        lens_per_seq: List[int] = []
        spec_seqs: List[SequenceGroupMetadata] = []
        spec_indices: List[int] = []

        for position, metadata in enumerate(seq_group_metadata_list):
            # The speculative decoding for this request has been disabled
            # (e.g. due to high traffic); leave the metadata untouched.
            if metadata.num_speculative_tokens == 0:
                lens_per_seq.append(0)
                continue

            first_seq_data = next(iter(metadata.seq_data.values()))
            current_len = first_seq_data.get_len()

            # Only two proposal lengths are supported: zero or the global
            # batch proposal length. The optional max_proposal_len limit
            # decides which one applies.
            within_limit = (self.max_proposal_len is None or
                            current_len + proposal_len < self.max_proposal_len)
            assigned_len = proposal_len if within_limit else 0
            if within_limit:
                spec_seqs.append(metadata)
                spec_indices.append(position)

            lens_per_seq.append(assigned_len)
            metadata.num_speculative_tokens = assigned_len

        return lens_per_seq, spec_seqs, spec_indices

    @staticmethod
    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
                                 nonzero_proposal_len_indices, transposed):
        """Drop sequences for which the draft worker returned no proposal.

        For every entry of nonzero_proposal_len_indices whose sampler output
        is None, the sequence's proposal length is reset to 0 and the
        sequence is removed from the index list. This can avoid scoring
        overheads.

        The inputs are returned unchanged when maybe_sampler_output is None
        (no proposal for any sequence) or when the output is transposed.
        Transposed outputs are not supported here because it is not
        straightforward for draft workers emitting them to signal "no
        proposal".

        Returns:
            A tuple of (proposal lengths, sampler outputs, nonzero proposal
            length indices) with the no-proposal sequences filtered out.
        """
        if maybe_sampler_output is None or transposed:
            return (proposal_lens, maybe_sampler_output,
                    nonzero_proposal_len_indices)

        kept_lens: List[int] = []
        kept_indices: List[int] = []
        kept_outputs: List[SamplerOutput] = []

        total_seqs = len(proposal_lens)
        total_candidates = len(nonzero_proposal_len_indices)
        candidate_ptr = 0
        position = 0
        while position < total_seqs and candidate_ptr < total_candidates:
            if position < nonzero_proposal_len_indices[candidate_ptr]:
                # This sequence was never sent to the draft worker, so it
                # must already have a proposal length of 0.
                assert proposal_lens[position] == 0
                kept_lens.append(0)
                position += 1
                continue

            # This sequence was sent to the draft worker; consume its output.
            draft_output = maybe_sampler_output[candidate_ptr]
            candidate_ptr += 1
            if draft_output is None:
                # The draft worker produced nothing for it.
                kept_lens.append(0)
            else:
                # Keep the proposal together with its batch index.
                kept_lens.append(proposal_lens[position])
                kept_indices.append(position)
                kept_outputs.append(draft_output)
            position += 1

        # The remaining sequences should have proposal length of 0.
        kept_lens.extend(proposal_lens[position:])

        # We assume sampler_output will not be a list of all Nones.
        # In this case this function should not be called.
        assert kept_outputs
        return (kept_lens, kept_outputs, kept_indices)

    def _merge_outputs(
        self,
        batch_size: int,
        proposal_len: int,
        maybe_sampler_output: Optional[List[SamplerOutput]],
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor,
               torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.

        Args:
            batch_size: Number of sequences in the full batch.
            proposal_len: Proposal length given to speculated sequences.
            maybe_sampler_output: Draft sampler output for the speculated
                sequences, or None when nothing was speculated.
            proposal_lens: Per-sequence proposal lengths; only its length is
                used here.
            nonzero_proposal_len_indices: Batch indices of the speculated
                sequences; their rows receive the draft results.
            sampler_transposed: Forwarded to ``sampler_output_to_torch``.

        Returns:
            A tuple ``(proposal_tokens, proposal_probs, proposal_lens,
            cart_candidates, tree_attn_masks)``. ``proposal_probs`` is a
            zero placeholder tensor when there is no sampler output and
            ``None`` otherwise.
        """
        if maybe_sampler_output is None:
            # If no speculative tokens, the sampler output will be None.
            # In this case we return empty proposals, built by expanding
            # scalar tensors to the expected shapes.
            retrieve_indices = self.tree_buffers["retrieve_indices"]
            mask_buffer = self.tree_buffers["tree_attn_masks"]

            proposal_tokens = torch.tensor(-1,
                                           dtype=torch.long,
                                           device=self._device).expand(
                                               batch_size, proposal_len)
            proposal_probs = torch.tensor(0,
                                          dtype=torch.float32,
                                          device=self._device).expand(
                                              batch_size, proposal_len,
                                              self._vocab_size)
            proposal_lens_tensor = torch.tensor(0,
                                                dtype=torch.long,
                                                device=self._device).expand(
                                                    len(proposal_lens))
            cart_candidates_tensor = torch.tensor(
                0, dtype=torch.long,
                device=self._device).expand(batch_size,
                                            retrieve_indices.shape[0],
                                            retrieve_indices.shape[1])
            # NOTE(review): this placeholder is 3-D, while the populated path
            # below returns the masks flattened to 2-D. Confirm consumers
            # ignore the masks when there are no proposals.
            tree_attn_masks_tensor = torch.tensor(
                0, dtype=torch.int32,
                device=self._device).expand(batch_size,
                                            mask_buffer.shape[0],
                                            mask_buffer.shape[1])

            return (proposal_tokens, proposal_probs, proposal_lens_tensor,
                    cart_candidates_tensor, tree_attn_masks_tensor)

        # Only the tokens, candidates and masks of the draft output are used;
        # the remaining returned values are discarded.
        proposal_tokens, _, _, _, cart_candidates, tree_attn_masks = (
            sampler_output_to_torch(maybe_sampler_output, sampler_transposed))

        # Now, reformat the output GPU tensors such that each sequence has
        # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
        entire_proposal_tokens = proposal_tokens.new_full(
            size=(batch_size, *proposal_tokens.shape[1:]),
            fill_value=-1,
        )
        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens

        # Probabilities are not returned on this path.
        entire_proposal_probs = None

        proposal_lens_tensor = torch.zeros(batch_size,
                                           dtype=torch.long,
                                           device=self._device)
        proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len

        # Rows of non-speculated sequences stay zero.
        entire_cart_candidates = cart_candidates.new_zeros(
            batch_size,
            *cart_candidates.shape[1:],
        )
        entire_cart_candidates[nonzero_proposal_len_indices] = cart_candidates

        entire_tree_attn_masks = tree_attn_masks.new_zeros(
            batch_size,
            *tree_attn_masks.shape[1:],
        )
        entire_tree_attn_masks[nonzero_proposal_len_indices] = tree_attn_masks
        # Collapse every leading dimension, keeping only the last one.
        entire_tree_attn_masks = entire_tree_attn_masks.reshape(
            -1, tree_attn_masks.shape[-1])

        return (entire_proposal_tokens, entire_proposal_probs,
                proposal_lens_tensor, entire_cart_candidates,
                entire_tree_attn_masks)