# batch_expansion.py (22.2 KB)
1
2
# SPDX-License-Identifier: Apache-2.0

3
from array import array
4
from itertools import chain, count
5
from typing import Iterator, List, Optional, Tuple
6
7
8

import torch

9
from vllm import SamplingParams
10
from vllm.model_executor.layers.sampler import SamplerOutput
11
12
13
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                           ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
14
15
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
16
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
17
18
19
20
21

# Integer aliases that make signatures self-describing:
#   SeqId       -- id of a sequence in the incoming (original) batch
#   TargetSeqId -- id given to a sequence in the expanded scoring batch
#   TokenId     -- vocabulary token id
SeqId = TargetSeqId = TokenId = int

# A SamplingParams instance constructed with all-default arguments.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens via the scorer model.

        This converts each input sequence to a set of k+1 target sequences. The
        target sequences have the unique continuations to be scored and a
        unique sequence ID that is different from all input sequence ids.

        If a speculative sequence length would exceed the max model length, then
        no speculation is produced for that sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.
        Returns:
            SpeculativeScores: The scores of each speculative token (plus the
                bonus token), contracted back to the original batch size.
                Sequences that were not speculated on only have their first
                position populated.
        """
        # TODO(cade) perform this on GPU to remove blocking call.
        proposal_lens_list = proposals.proposal_lens.tolist()
        proposal_token_ids_list = proposals.proposal_token_ids.tolist()

        # Drop proposal rows that contain VLLM_INVALID_TOKEN_ID. The surviving
        # rows are indexed positionally against the speculative sequences in
        # _create_scoring_model_input.
        # NOTE(review): this relies on exactly the zero-length proposals
        # carrying the invalid token id — confirm against the proposer.
        # The loop variable is intentionally not called `proposals` so it does
        # not visually shadow the parameter of the same name.
        proposal_token_ids_list_without_skips = [
            token_ids for token_ids in proposal_token_ids_list
            if VLLM_INVALID_TOKEN_ID not in token_ids
        ]

        (spec_indices, non_spec_indices, target_seq_group_metadata_list,
         num_scoring_tokens) = self._expand_batch(
             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
             proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        assert len(target_sampler_output) == 1, "expected single-step output"
        target_sampler_output = target_sampler_output[0]

        if not non_spec_indices:
            # All sequence groups in batch have spec decoding enabled
            return self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )
        else:
            # Batch has a mix of spec decode enabled and disabled seq groups
            return self._contract_batch(
                execute_model_req.seq_group_metadata_list,
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Given the input sequences and potentially multiple corresponding
        proposal tokens, create a new batch where each sequence has a single
        query token.

        Returns:
            A tuple ``(spec_indices, non_spec_indices,
            target_seq_group_metadata_list, num_scoring_tokens)``, where the
            target list holds the non-speculative sequences first, followed by
            the expanded speculative sequences, and ``num_scoring_tokens`` is
            the number of expanded speculative entries.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
            split_batch_by_proposal_len(
                seq_group_metadata_list, proposal_lens_list)

        spec_expanded_seqs = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            # NOTE: We determine the seq ids in the expanded batch using the
            # full seq_group_metadata_list, instead of only spec_seqs.
            target_seq_ids_iter=self._create_target_seq_id_iterator(
                seq_ids=get_all_seq_ids(seq_group_metadata_list)),
        )

        num_scoring_tokens = len(spec_expanded_seqs)
        # Batch speculative and non-speculative (e.g. chunked prefill) requests
        # but make sure order is prefill|decode due to backend requirement.
        target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs

        return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
                num_scoring_tokens)

    def _contract_non_speculative(
            self, scores: SpeculativeScores,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
            has_prompt_log: bool) -> SpeculativeScores:
        """
            Augment input `scores` with non-speculative requests outputs.
            This includes decode requests with speculation turned off, as well
            as prefill requests when `enable_chunked_prefill` is set.
            For the latter, prefills are further separated into terminal and
            non-terminal chunks (from which no token is sampled).

            Only position 0 of each non-speculative row in `scores` is written;
            `scores` is modified in place and also returned.
        """
        if not non_spec_indices:
            return scores

        if has_prompt_log:
            # When prompt_logprobs is enabled, prefills yield output token
            # (and respective prob) in the last entry (prompt|out):
            # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
            # With chunked prefill, non-terminal chunks have -1 on each
            # position: they're still picked, but they're discarded later.
            seq_meta = seq_group_metadata_list
            nospec_sizes = torch.tensor([
                seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
                for i in non_spec_indices
            ])
            # Last entry of each per-sequence span in the flattened output.
            nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
        else:
            # In this case only sampled tokens are returned, select all.
            nospec_sampled_token_idxs = list(
                range(len(non_spec_outputs.token_ids)))

        scores.token_ids[non_spec_indices, :1] = \
            non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
        scores.probs[non_spec_indices, :1, :] = \
            non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
        scores.logprobs[non_spec_indices, :1, :] = \
            non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
        if scores.hidden_states is not None:
            assert non_spec_outputs.hidden_states is not None
            scores.hidden_states[non_spec_indices, :1, :] = (
                non_spec_outputs.hidden_states[nospec_sampled_token_idxs]
                .unsqueeze(1))
        return scores

    def _contract_batch(
            self,
            contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
            target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int],
            k: int) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        contracted_bs is the original batch size, and the batch size that the
        target_sampler_output will be contracted to.

        Note: the ``k`` argument is overridden by the width of
        ``proposals.proposal_token_ids``; it is kept for interface
        compatibility.
        """
        contracted_bs = len(contracted_seq_group_metadata_list)
        (target_token_ids, target_probs, target_logprobs, target_hidden_states,
         non_spec_target_token_ids, non_spec_target_probs,
         non_spec_target_logprobs,
         non_spec_target_hidden_states) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        # proposal_token_ids is [batch_size, k]; this rebinds `k`.
        batch_size, k = proposals.proposal_token_ids.shape

        # The number of original sequences that were speculated on is the
        # batch size minus the number of non-speculative sequences, prefill
        # chunks with no out tokens included.
        spec_batch_size = batch_size - len(non_spec_indices)

        target_token_ids = target_token_ids.reshape(spec_batch_size, k + 1)
        target_probs = target_probs.reshape(*target_token_ids.shape,
                                            self._vocab_size)
        target_logprobs = target_logprobs.reshape(target_probs.shape)

        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        # Placeholders for the full batch: rows/positions not filled below keep
        # token -1, prob 0 and logprob -inf.
        all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                               fill_value=-1)
        all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                fill_value=-float("inf"))

        if target_sampler_output.hidden_states is not None:
            all_hidden_states = target_hidden_states.new_zeros(
                size=(contracted_bs, k + 1, target_hidden_states.shape[-1]))
        else:
            all_hidden_states = None

        has_prompt_log = any((sg.sampling_params.prompt_logprobs
                              and sg.sampling_params.prompt_logprobs > 0)
                             for sg in contracted_seq_group_metadata_list)
        # When prompt logprobs is enabled, lens of returned tensors go from
        # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
        # We adjust stride accordingly to get the generated tokens and
        # their probs, but pass on prompt_logprobs as is.
        prompt_logprobs = None
        if (not self._scorer_worker.model_runner.disable_logprobs
                and has_prompt_log):
            prompt_logprobs = [
                o.prompt_logprobs for o in target_sampler_output.outputs
            ]
        elif not has_prompt_log:
            # When prompt logprobs are not to be returned,
            # we can ignore non-terminal chunks (no out token).
            non_spec_indices = [
                idx for idx in non_spec_indices
                if contracted_seq_group_metadata_list[idx].do_sample
            ]

        # "Contract" speculative.
        if spec_indices:
            all_tokens[spec_indices] = target_token_ids
            all_probs[spec_indices] = target_probs
            all_logprobs[spec_indices] = target_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = target_hidden_states

        spec_scores = SpeculativeScores(probs=all_probs,
                                        token_ids=all_tokens,
                                        logprobs=all_logprobs,
                                        hidden_states=all_hidden_states,
                                        prompt_logprobs=prompt_logprobs)

        non_spec_outputs = SpeculativeScores(
            probs=non_spec_target_probs,
            token_ids=non_spec_target_token_ids,
            logprobs=non_spec_target_logprobs,
            hidden_states=non_spec_target_hidden_states)
        # Contract remaining nonspec entries based on non_spec_indices, if any.
        return self._contract_non_speculative(
            spec_scores, contracted_seq_group_metadata_list, non_spec_indices,
            non_spec_outputs, has_prompt_log)

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Contract the expanded batch back into its original size.
        This maps the scores of speculative tokens back to their original
        sequences.

        It assumes all sequences in the batch were previously expanded.
        """

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        # Reshape tensors to original batch size
        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        target_probs = target_sampler_output.sampled_token_probs.reshape(
            *target_token_ids.shape, self._vocab_size)
        target_logprobs = target_sampler_output.logprobs.reshape(
            target_probs.shape)
        target_hidden_states = target_sampler_output.hidden_states
        if target_hidden_states is not None:
            target_hidden_states = target_hidden_states.reshape(
                *target_token_ids.shape, target_hidden_states.shape[-1])

        return SpeculativeScores(probs=target_probs,
                                 token_ids=target_token_ids,
                                 logprobs=target_logprobs,
                                 hidden_states=target_hidden_states,
                                 prompt_logprobs=None)

    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given the original input sequences and proposed tokens from the draft
        model, create a list of target sequences that can be used for scoring.

        target_seq_ids_iter provides sequence ids for the expanded batch,
        fulfilling the requirement that no seq id in the expanded batch is equal
        to the seq id in the original batch.

        Row ``i`` of ``proposal_token_ids`` is paired with entry ``i`` of
        ``seq_group_metadata_list``.
        """

        if not seq_group_metadata_list:
            return []

        target_seq_group_metadata = list(
            chain.from_iterable(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    i,
                    target_seq_ids_iter,
                ) for i, seq_group_metadata in enumerate(
                    seq_group_metadata_list)))

        return target_seq_group_metadata

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Given an input sequence group metadata and a list of draft tokens,
        create a list of target SequenceGroupMetadata, one for each
        token id that needs to be scored.

        Naive speculative decoding requires K target model scores, one for each
        draft model token. However one can add a bonus token such that if each
        token is accepted, then a final token may be sampled from the model.
        This function creates K+1 target SequenceGroupMetadata to take
        advantage of the bonus token.

        Args:
            input_seq_group_metadata: The (single-sequence) group to expand.
            proposal_token_ids: Proposal token ids for the whole batch.
            batch_index: Row of ``proposal_token_ids`` for this group.
            target_seq_ids_iter: Source of fresh target sequence ids; one id
                is consumed per created entry.
        """
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search "
            "not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))

        token_ids_to_score = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        sampling_params = input_seq_group_metadata.sampling_params
        return [
            self._create_single_target_seq_group_metadata(
                input_seq_group_metadata,
                input_seq_id,
                next(target_seq_ids_iter),
                token_ids,
                sampling_params=sampling_params,
            ) for token_ids in token_ids_to_score
        ]

    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Create a single target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The list of token ids that are to be appended to the
                input sequence.
            sampling_params: Sampling parameters attached to the new group.
        """
        seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = seq_data.prompt_token_ids_array
        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
        mrope_position_delta = seq_data.mrope_position_delta

        new_seq_data_dict = {
            target_seq_id:
            SequenceData(
                prompt_token_ids,
                _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                        new_output_token_ids),
            ),
        }
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        for data in new_seq_data_dict.values():
            data.update_num_computed_tokens(data.get_len() - 1)
            data.mrope_position_delta = mrope_position_delta

        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data=new_seq_data_dict,
            sampling_params=sampling_params,
            # The target sequence reuses the input sequence's block table.
            block_tables={
                target_seq_id: seq_group_metadata.block_tables[seq_id],
            },
            lora_request=None,
            token_chunk_size=1,
        )

    @staticmethod
    def _split_scoring_output(
        sampler_output: SamplerOutput, num_scoring_tokens: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
        """Split the target model output into speculative and non-speculative
        output.

        Returns:
            ``(spec_tokens, spec_probs, spec_logprobs, spec_hidden_states,
            non_spec_tokens, non_spec_probs, non_spec_logprobs,
            non_spec_hidden_states)``; the hidden-state entries are None when
            the sampler output carries no hidden states.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are non-speculative, latter samples are from speculative
        # scoring (prefill|decode order).
        split_sizes = (sampler_output.sampled_token_ids.numel() -
                       num_scoring_tokens, num_scoring_tokens)
        (non_spec_probs,
         spec_probs) = sampler_output.sampled_token_probs.split(split_sizes)
        (non_spec_sampled_tokens, spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
        (non_spec_logprobs,
         spec_logprobs) = sampler_output.logprobs.split(split_sizes)

        if sampler_output.hidden_states is not None:
            (non_spec_hidden_states, spec_hidden_states
             ) = sampler_output.hidden_states.split(split_sizes)
        else:
            non_spec_hidden_states, spec_hidden_states = None, None

        return (spec_sampled_tokens, spec_probs, spec_logprobs,
                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
                non_spec_logprobs, non_spec_hidden_states)

    @staticmethod
    def _create_target_seq_id_iterator(
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids. With no input ids there is nothing to
        collide with, so the counter starts at 0 instead of raising.
        """
        return count(start=max(seq_ids, default=-1) + 1)

    @staticmethod
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given a list of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        empty_token_ids: List[TokenId] = []

        token_ids_to_score = [empty_token_ids]
        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                                  for i in range(len(full_spec_token_ids)))
        return token_ids_to_score