# batch_expansion.py
from array import array
2
from itertools import chain, count
3
from typing import Iterator, List, Optional, Tuple
4
5
6

import torch

7
from vllm import SamplingParams
8
from vllm.model_executor.layers.sampler import SamplerOutput
9
10
11
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
                           ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata, get_all_seq_ids)
12
13
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
14
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
15
from vllm.worker.worker_base import WorkerBase
16
17
18
19
20

# Integer aliases used in the signatures below to label what an id refers to.
# SeqId: id of a sequence in the original (input) batch.
SeqId = int
# TargetSeqId: id given to a sequence created for the expanded scoring batch.
TargetSeqId = int
# TokenId: a single token id.
TokenId = int

21
22
# Default-constructed sampling params. _create_target_seq_group_metadata uses
# this shared instance for the non-final scoring positions of any sequence
# whose own sampling temperature is non-zero.
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

class BatchExpansionTop1Scorer(SpeculativeScorer):
    """Implements a speculative scorer that uses batch expansion to get
    probabilities of speculative tokens according to the scoring model.

    Batch expansion converts a list of sequences and multiple query positions
    to a new batch of sequences, each with a single query position. This allows
    for MQA-like scoring in speculative decoding without requiring an MQA
    kernel.

    It is strictly less efficient than MQA scoring.

    It only supports scoring the top1 proposal tokens of the proposer, instead
    of topk/tree.
    """

39
40
    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        """Remember the collaborators needed for scoring.

        Args:
            scorer_worker: Worker whose ``execute_model`` is invoked on the
                expanded batch.
            device: Device identifier; stored on the instance.
            vocab_size: Size of the trailing dimension that probability and
                logprob tensors are reshaped to when contracting the batch.
        """
        self._vocab_size = vocab_size
        self._device = device
        self._scorer_worker = scorer_worker

    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens with the scorer model.

        The batch is expanded so that every position to be scored becomes its
        own target sequence (with a sequence id distinct from all input
        sequence ids), the scorer worker is run once on that expanded batch,
        and the result is contracted back to one row per input sequence.

        Args:
            execute_model_req: The execution request.
            proposals: The speculative proposals to score.

        Returns:
            SpeculativeScores: The probabilities, token ids, logprobs, hidden
                states and logits produced by the scorer for the batch.
        """
        # TODO(cade) perform this on GPU to remove blocking call.
        lens_per_seq = proposals.proposal_lens.tolist()
        token_rows = proposals.proposal_token_ids.tolist()

        # Ignore invalid proposals: drop every row that contains the
        # invalid-token sentinel.
        valid_token_rows = [
            row for row in token_rows if VLLM_INVALID_TOKEN_ID not in row
        ]

        expanded = self._expand_batch(
            seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
            proposal_token_ids_list=valid_token_rows,
            proposal_lens_list=lens_per_seq,
        )
        (spec_indices, non_spec_indices, expanded_seq_groups,
         num_scoring_tokens) = expanded

        scoring_request = execute_model_req.clone(
            seq_group_metadata_list=expanded_seq_groups)
        step_outputs = self._scorer_worker.execute_model(
            execute_model_req=scoring_request)
        assert len(step_outputs) == 1, "expected single-step output"
        target_sampler_output = step_outputs[0]

        if non_spec_indices:
            # Batch has a mix of spec decode enabled and disabled seq groups.
            contracted = self._contract_batch(
                contracted_bs=len(execute_model_req.seq_group_metadata_list),
                target_sampler_output=target_sampler_output,
                proposals=proposals,
                num_scoring_tokens=num_scoring_tokens,
                non_spec_indices=non_spec_indices,
                spec_indices=spec_indices,
                k=execute_model_req.num_lookahead_slots,
            )
        else:
            # Every sequence group in the batch has spec decoding enabled.
            contracted = self._contract_batch_all_spec(
                target_sampler_output=target_sampler_output,
                proposals=proposals,
            )

        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
        return SpeculativeScores(
            probs=all_probs,
            token_ids=all_tokens,
            logprobs=spec_logprobs,
            hidden_states=all_hidden_states,
            logits=target_sampler_output.logits,
        )

    def _expand_batch(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids_list: List[List[TokenId]],
        proposal_lens_list: List[int],
    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
        """Build the expanded batch in which each sequence has a single query
        token.

        Args:
            seq_group_metadata_list: The original input batch.
            proposal_token_ids_list: Proposed token ids, one row per
                speculative sequence.
            proposal_lens_list: Proposal length of every input sequence, used
                to split the batch into speculative and non-speculative parts.

        Returns:
            A 4-tuple ``(spec_indices, non_spec_indices, expanded_batch,
            num_scoring_tokens)``. ``expanded_batch`` holds the scoring
            sequences first, followed by the non-speculative sequences
            unchanged; ``num_scoring_tokens`` is the number of scoring
            sequences at its front.
        """
        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. That is why the batch has to be split into spec and
        # non-spec sequences; supporting per-sequence proposal lens would
        # allow removing this.
        spec_part, non_spec_part = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        spec_seqs, spec_indices = spec_part
        non_spec_seqs, non_spec_indices = non_spec_part

        # NOTE: The seq ids for the expanded batch are derived from the full
        # seq_group_metadata_list, not only from spec_seqs.
        fresh_seq_ids = self._create_target_seq_id_iterator(
            seq_ids=get_all_seq_ids(seq_group_metadata_list))
        expanded_batch = self._create_scoring_model_input(
            seq_group_metadata_list=spec_seqs,
            proposal_token_ids=proposal_token_ids_list,
            target_seq_ids_iter=fresh_seq_ids,
        )

        # Record how many entries are scoring sequences before the
        # non-speculative ones are appended behind them.
        num_scoring_tokens = len(expanded_batch)
        expanded_batch.extend(non_spec_seqs)

        return (spec_indices, non_spec_indices, expanded_batch,
                num_scoring_tokens)
151

152
    def _contract_batch(
        self, contracted_bs: int, target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals, num_scoring_tokens: int,
        non_spec_indices: List[int], spec_indices: List[int], k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Fold the expanded batch back to one row per original sequence.

        Scores of speculative tokens are written to the rows of the sequences
        they were proposed for; non-speculative sequences only get their first
        column filled.

        Args:
            contracted_bs: Original batch size, i.e. the number of rows in the
                returned tensors.
            target_sampler_output: Scorer output for the expanded batch.
            proposals: The proposals that were scored.
            num_scoring_tokens: Number of leading samples in the scorer output
                that belong to speculative scoring.
            non_spec_indices: Row indices of non-speculative sequences.
            spec_indices: Row indices of speculative sequences.
            k: Overwritten below by the width of the proposal tensor.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            ``[contracted_bs, k + 1]``; ``hidden_states`` is ``None`` when the
            sampler output has none.
        """
        (spec_tokens, spec_probs, spec_logprobs, spec_hidden,
         non_spec_tokens, non_spec_probs, non_spec_logprobs,
         non_spec_hidden) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
        expanded_batch_size, k = proposals.proposal_token_ids.shape
        positions = k + 1

        # Rows used for speculation = first dim of the proposal tensor minus
        # the number of samples from non-speculative sequences.
        spec_rows = expanded_batch_size - len(non_spec_tokens)

        spec_tokens = spec_tokens.reshape(spec_rows, positions)
        spec_probs = spec_probs.reshape(*spec_tokens.shape, self._vocab_size)
        spec_logprobs = spec_logprobs.reshape(spec_probs.shape)
        if spec_hidden is not None:
            spec_hidden = spec_hidden.reshape(*spec_tokens.shape,
                                              spec_hidden.shape[-1])

        # Output buffers, pre-filled with placeholders: -1 tokens, zero
        # probabilities and -inf logprobs.
        all_tokens = spec_tokens.new_full(size=(contracted_bs, positions),
                                          fill_value=-1)
        all_probs = spec_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = spec_logprobs.new_full(size=all_probs.shape,
                                              fill_value=-float("inf"))

        all_hidden_states = None
        if target_sampler_output.hidden_states is not None:
            all_hidden_states = spec_hidden.new_zeros(
                size=(contracted_bs, positions, spec_hidden.shape[-1]))

        if non_spec_indices:
            # Non-speculative sequences contribute a single sample each, which
            # goes into column 0.
            all_tokens[non_spec_indices, :1] = non_spec_tokens.unsqueeze(1)
            all_probs[non_spec_indices, :1, :] = non_spec_probs.unsqueeze(1)
            all_logprobs[non_spec_indices, :1, :] = \
                non_spec_logprobs.unsqueeze(1)
            if all_hidden_states is not None:
                assert non_spec_target_hidden_states_present(non_spec_hidden) \
                    if False else non_spec_hidden is not None
                all_hidden_states[non_spec_indices, :1, :] = \
                    non_spec_hidden.unsqueeze(1)

        if spec_indices:
            all_tokens[spec_indices] = spec_tokens
            all_probs[spec_indices] = spec_probs
            all_logprobs[spec_indices] = spec_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = spec_hidden

        return all_tokens, all_probs, all_logprobs, all_hidden_states
222

223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Fold the expanded batch back to its original size when every
        sequence in the batch was expanded.

        Because no non-speculative samples are mixed in, the sampler output
        only needs to be reshaped.

        Args:
            target_sampler_output: Scorer output for the expanded batch.
            proposals: The proposals that were scored; the shape of
                ``proposal_token_ids`` supplies the batch size and ``k``.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            ``[batch_size, k + 1]``; ``hidden_states`` is ``None`` when the
            sampler output has none.
        """
        # One expanded sequence was scored per position, so the flat output of
        # shape [batch_size * (k + 1)] maps back to [batch_size, k + 1].
        contracted_bs, k = proposals.proposal_token_ids.shape

        token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k + 1)
        probs = target_sampler_output.sampled_token_probs.reshape(
            *token_ids.shape, self._vocab_size)
        logprobs = target_sampler_output.logprobs.reshape(probs.shape)

        hidden_states = target_sampler_output.hidden_states
        if hidden_states is not None:
            hidden_states = hidden_states.reshape(*token_ids.shape,
                                                  hidden_states.shape[-1])

        return token_ids, probs, logprobs, hidden_states

255
    def _create_scoring_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Turn the input sequences and their draft tokens into the flat list
        of target sequences used for scoring.

        Args:
            seq_group_metadata_list: Input sequences to expand.
            proposal_token_ids: Draft tokens; row ``i`` belongs to
                ``seq_group_metadata_list[i]``.
            target_seq_ids_iter: Source of sequence ids for the expanded
                batch, so that no expanded sequence reuses a seq id of the
                original batch.

        Returns:
            The target sequences of all inputs, concatenated in input order
            (empty when there are no inputs).
        """
        expanded: List[SequenceGroupMetadata] = []
        for batch_index, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            expanded.extend(
                self._create_target_seq_group_metadata(
                    seq_group_metadata,
                    proposal_token_ids,
                    batch_index,
                    target_seq_ids_iter,
                ))
        return expanded

    def _create_target_seq_group_metadata(
        self,
        input_seq_group_metadata: SequenceGroupMetadata,
        proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
        batch_index: int,
        target_seq_ids_iter: Iterator[TargetSeqId],
    ) -> List[SequenceGroupMetadata]:
        """Expand one input sequence group into its target sequence groups,
        one per token-id prefix returned by ``_get_token_ids_to_score``.

        Naive speculative decoding needs K target model scores, one per draft
        token. Scoring one extra position allows a bonus token to be sampled
        from the target model when every draft token is accepted.

        Args:
            input_seq_group_metadata: The sequence group to expand; must be a
                decode (non-prompt) group holding exactly one sequence.
            proposal_token_ids: Draft tokens for the whole batch.
            batch_index: Row of ``proposal_token_ids`` belonging to this
                sequence group.
            target_seq_ids_iter: Source of ids for the created sequences.

        Returns:
            The created target sequence groups, in prefix order.
        """
        assert not input_seq_group_metadata.is_prompt, (
            "Speculating on prompts not yet supported")
        assert len(input_seq_group_metadata.seq_data) == 1, (
            "Beam search not supported in speculative decoding")
        input_seq_id = next(iter(input_seq_group_metadata.seq_data))

        prefixes = self._get_token_ids_to_score(
            proposal_token_ids[batch_index])

        # Use simpler sampling parameters apart from for final token
        # (in particular don't do seeded sampling) since those sampled tokens
        # aren't used.
        # We don't replace the sampling_params in the greedy case because
        # this also controls whether the probs get modified in the sampler
        # (see use of _modify_greedy_probs_inplace there).
        sampling_params = input_seq_group_metadata.sampling_params
        if sampling_params.temperature:
            non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS
        else:
            non_bonus_sampling_params = sampling_params

        final_position = len(prefixes) - 1
        return [
            self._create_single_target_seq_group_metadata(
                input_seq_group_metadata,
                input_seq_id,
                next(target_seq_ids_iter),
                token_ids,
                sampling_params=(sampling_params if position == final_position
                                 else non_bonus_sampling_params),
            ) for position, token_ids in enumerate(prefixes)
        ]

338
    @staticmethod
    def _create_single_target_seq_group_metadata(
        seq_group_metadata: SequenceGroupMetadata,
        seq_id: SeqId,
        target_seq_id: TargetSeqId,
        token_ids: List[TokenId],
        sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Build one target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The token ids to append to the input sequence's
                existing output tokens.
            sampling_params: Sampling params attached to the new group.

        Returns:
            A single-sequence group keyed by ``target_seq_id`` that reuses the
            input sequence's request id, prompt tokens and block table.
        """
        source_seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = source_seq_data.prompt_token_ids_array
        extended_output = [*source_seq_data.get_output_token_ids(), *token_ids]

        target_seq_data = SequenceData(
            prompt_token_ids,
            _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, extended_output),
        )
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        target_seq_data.update_num_computed_tokens(
            target_seq_data.get_len() - 1)

        source_block_table = seq_group_metadata.block_tables[seq_id]
        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={target_seq_id: target_seq_data},
            sampling_params=sampling_params,
            block_tables={target_seq_id: source_block_table},
            lora_request=None,
            token_chunk_size=1,
        )

386
    @staticmethod
387
    def _split_scoring_output(
388
        sampler_output: SamplerOutput, num_scoring_tokens: int
389
390
391
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor], torch.Tensor, torch.Tensor,
               torch.Tensor, Optional[torch.Tensor]]:
392
393
394
395
396
397
398
399
400
401
402
        """Split the target model output into speculative and non-speculative
        output.
        """

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        #
        # First samples are from speculative scoring, latter samples are non-
        # speculative samples.
403
404
405
        split_sizes = (num_scoring_tokens,
                       sampler_output.sampled_token_ids.numel() -
                       num_scoring_tokens)
406
407
408
409
        (spec_probs, non_spec_probs
         ) = sampler_output.sampled_token_probs.split(split_sizes)
        (spec_sampled_tokens, non_spec_sampled_tokens
         ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
410
411
412
413
        (
            spec_logprobs,
            non_spec_logprobs,
        ) = sampler_output.logprobs.split(split_sizes)
414

415
416
417
418
419
420
421
422
        if sampler_output.hidden_states is not None:
            (
                spec_hidden_states,
                non_spec_hidden_states,
            ) = sampler_output.hidden_states.split(split_sizes)
        else:
            spec_hidden_states, non_spec_hidden_states = None, None

423
424
425
        return (spec_sampled_tokens, spec_probs, spec_logprobs,
                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
                non_spec_logprobs, non_spec_hidden_states)
426

427
    @staticmethod
428
    def _create_target_seq_id_iterator(
429
            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
430
431
432
433
434
435
436
437
438
        """Create an iterator for creating target sequence ids.
        Target sequence ids are distinct from sequence ids because we create a
        distinct target sequence id for each proposal token to be scored.

        This implementation increments a counter starting at 1 + max of all
        provided input sequence ids.
        """
        return count(start=max(seq_ids) + 1)

439
    @staticmethod
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
    def _get_token_ids_to_score(
        full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                []
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
458
        empty_token_ids: List[TokenId] = []
459
460

        token_ids_to_score = [empty_token_ids]
461
462
        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                                  for i in range(len(full_spec_token_ids)))
463
        return token_ids_to_score
    
class BatchExpansionTreeStyleScorer(BatchExpansionTop1Scorer):

    def __init__(self, scorer_worker: WorkerBase, device: str,
                 vocab_size: int):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(scorer_worker, device, vocab_size)

    def _contract_batch(
            self, contracted_bs: int, target_sampler_output: SamplerOutput,
            proposals: SpeculativeProposals, num_scoring_tokens: int,
            non_spec_indices: List[int], spec_indices: List[int], k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Fold the expanded batch back to one row per original sequence.

        Same procedure as the parent implementation, except that each
        sequence has ``k`` scored positions here instead of ``k + 1``.

        Args:
            contracted_bs: Original batch size, i.e. the number of rows in the
                returned tensors.
            target_sampler_output: Scorer output for the expanded batch.
            proposals: The proposals that were scored.
            num_scoring_tokens: Number of leading samples in the scorer output
                that belong to speculative scoring.
            non_spec_indices: Row indices of non-speculative sequences.
            spec_indices: Row indices of speculative sequences.
            k: Overwritten below by the width of the proposal tensor.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            ``[contracted_bs, k]``; ``hidden_states`` is ``None`` when the
            sampler output has none.
        """
        (spec_tokens, spec_probs, spec_logprobs, spec_hidden,
         non_spec_tokens, non_spec_probs, non_spec_logprobs,
         non_spec_hidden) = self._split_scoring_output(
             target_sampler_output, num_scoring_tokens)

        # Map distinct sequences used to score each token
        # of shape [batch_size * k] back to [batch_size, k].
        expanded_batch_size, k = proposals.proposal_token_ids.shape

        # Rows used for speculation = first dim of the proposal tensor minus
        # the number of samples from non-speculative sequences.
        spec_rows = expanded_batch_size - len(non_spec_tokens)

        spec_tokens = spec_tokens.reshape(spec_rows, k)
        spec_probs = spec_probs.reshape(*spec_tokens.shape, self._vocab_size)
        spec_logprobs = spec_logprobs.reshape(spec_probs.shape)
        if spec_hidden is not None:
            spec_hidden = spec_hidden.reshape(*spec_tokens.shape,
                                              spec_hidden.shape[-1])

        # Output buffers, pre-filled with placeholders: -1 tokens, zero
        # probabilities and -inf logprobs.
        all_tokens = spec_tokens.new_full(size=(contracted_bs, k),
                                          fill_value=-1)
        all_probs = spec_probs.new_zeros(*all_tokens.shape, self._vocab_size)
        all_logprobs = spec_logprobs.new_full(size=all_probs.shape,
                                              fill_value=-float("inf"))

        all_hidden_states = None
        if target_sampler_output.hidden_states is not None:
            all_hidden_states = spec_hidden.new_zeros(
                size=(contracted_bs, k, spec_hidden.shape[-1]))

        if non_spec_indices:
            # Non-speculative sequences contribute a single sample each, which
            # goes into column 0.
            all_tokens[non_spec_indices, :1] = non_spec_tokens.unsqueeze(1)
            all_probs[non_spec_indices, :1, :] = non_spec_probs.unsqueeze(1)
            all_logprobs[non_spec_indices, :1, :] = \
                non_spec_logprobs.unsqueeze(1)
            if all_hidden_states is not None:
                assert non_spec_hidden is not None
                all_hidden_states[non_spec_indices, :1, :] = \
                    non_spec_hidden.unsqueeze(1)

        if spec_indices:
            all_tokens[spec_indices] = spec_tokens
            all_probs[spec_indices] = spec_probs
            all_logprobs[spec_indices] = spec_logprobs
            if all_hidden_states is not None:
                all_hidden_states[spec_indices] = spec_hidden

        return all_tokens, all_probs, all_logprobs, all_hidden_states

    def _contract_batch_all_spec(
        self,
        target_sampler_output: SamplerOutput,
        proposals: SpeculativeProposals,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
               Optional[torch.Tensor]]:
        """Fold the expanded batch back to its original size when every
        sequence in the batch was expanded.

        Because no non-speculative samples are mixed in, the sampler output
        only needs to be reshaped.

        Args:
            target_sampler_output: Scorer output for the expanded batch.
            proposals: The proposals that were scored; the shape of
                ``proposal_token_ids`` supplies the batch size and ``k``.

        Returns:
            ``(token_ids, probs, logprobs, hidden_states)`` with leading shape
            ``[batch_size, k]``; ``hidden_states`` is ``None`` when the
            sampler output has none.
        """
        # The flat sampler output is reshaped to [batch_size, k]: this scorer
        # uses k positions per sequence.
        contracted_bs, k = proposals.proposal_token_ids.shape

        token_ids = target_sampler_output.sampled_token_ids.reshape(
            contracted_bs, k)
        probs = target_sampler_output.sampled_token_probs.reshape(
            *token_ids.shape, self._vocab_size)
        logprobs = target_sampler_output.logprobs.reshape(probs.shape)

        hidden_states = target_sampler_output.hidden_states
        if hidden_states is not None:
            hidden_states = hidden_states.reshape(*token_ids.shape,
                                                  hidden_states.shape[-1])

        return token_ids, probs, logprobs, hidden_states

    @staticmethod
    def _create_single_target_seq_group_metadata(
            seq_group_metadata: SequenceGroupMetadata,
            seq_id: SeqId,
            target_seq_id: TargetSeqId,
            token_ids: List[TokenId],
            sampling_params: SamplingParams,
    ) -> SequenceGroupMetadata:
        """Build one target SequenceGroupMetadata.

        Args:
            seq_group_metadata: The metadata for the input sequence.
            seq_id: The input sequence ID.
            target_seq_id: The corresponding target sequence ID.
            token_ids: The token ids to append to the input sequence's
                existing output tokens.
            sampling_params: Sampling params attached to the new group.

        Returns:
            A single-sequence group keyed by ``target_seq_id`` that reuses the
            input sequence's request id, prompt tokens and block table.
        """
        source_seq_data = seq_group_metadata.seq_data[seq_id]
        prompt_token_ids = source_seq_data.prompt_token_ids_array

        # On the first step the last existing output token (generated by the
        # prefill phase) is left out before the proposal tokens are appended.
        is_first_step = source_seq_data.get_first_step_flag()
        prior_output = source_seq_data.get_output_token_ids()
        if is_first_step:
            prior_output = prior_output[:-1]
        extended_output = [*prior_output, *token_ids]

        target_seq_data = SequenceData(
            prompt_token_ids,
            _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, extended_output),
        )
        # This is a hack. Technically, spec decoding should compute
        # num_lookahead slots at one shot, but instead, it expands the batch
        # and evaluate one by one right now. context_len is seq_len - 1 because
        # the kv cache is filled by a previous batch in the batch expansion.
        target_seq_data.update_num_computed_tokens(
            target_seq_data.get_len() - 1)

        source_block_table = seq_group_metadata.block_tables[seq_id]
        return SequenceGroupMetadata(
            request_id=seq_group_metadata.request_id,
            is_prompt=seq_group_metadata.is_prompt,
            seq_data={target_seq_id: target_seq_data},
            sampling_params=sampling_params,
            block_tables={target_seq_id: source_block_table},
            lora_request=None,
            token_chunk_size=1,
        )

    def _get_token_ids_to_score(
            self,
            full_spec_token_ids: List[TokenId]  # shape: [k]
    ) -> List[List[TokenId]]:
        """Given an int tensor of proposal token ids, return a list of
        token ids that should be scored.

        Returns k+1 output lists. The additional one is used for generating the
        bonus token.

        Example:
            Input: [0, 1, 2, 3] (k=4)
            Output: (k+1 lists)
                [0]
                [0, 1]
                [0, 1, 2]
                [0, 1, 2, 3]
        """
        token_ids_to_score = []
        token_ids_to_score.extend([
            full_spec_token_ids[:i + 1]
            for i in range(len(full_spec_token_ids))
        ])
        return token_ids_to_score