"vllm/vscode:/vscode.git/clone" did not exist on "28b47d1e490b2b13ec282ac1cbe0eb51f908bfbd"
spec_decode_worker.py 30.3 KB
Newer Older
1
from functools import cached_property
2
from typing import Any, Dict, List, Optional, Tuple
3
4
5

import torch

6
from vllm.config import ParallelConfig, SpeculativeConfig
7
from vllm.distributed.communication_op import broadcast_tensor_dict
8
from vllm.logger import init_logger
9
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
10
11
12
13
from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
14
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
15
16
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
                           get_all_seq_ids)
17
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
18
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
19
20
21
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
22
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
23
from vllm.spec_decode.multi_step_worker import MultiStepWorker
24
from vllm.spec_decode.ngram_worker import NGramWorker
25
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
26
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
27
from vllm.spec_decode.util import (create_sequence_group_output,
28
                                   get_all_num_logprobs,
29
                                   get_sampled_token_logprobs, nvtx_range,
30
                                   split_batch_by_proposal_len)
31
from vllm.worker.worker import Worker
32
33
34
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)
35
36


37
38
39
40
41
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.

    Args:
        *args: Positional arguments forwarded unchanged to ``Worker`` when
            building the target (scorer) worker.
        **kwargs: Keyword arguments forwarded to ``Worker``. Must contain a
            non-None ``speculative_config``. A copy of these kwargs, with the
            draft-model specific entries overridden, is used to build the
            proposer worker.

    Returns:
        The worker produced by ``SpecDecodeWorker.create_worker``.
    """
    assert "speculative_config" in kwargs
    speculative_config: SpeculativeConfig = kwargs["speculative_config"]
    assert speculative_config is not None

    target_worker = Worker(*args, **kwargs)

    # Copy so that the overrides below (and the pops performed inside
    # create_worker) never leak back into the caller's kwargs.
    draft_worker_kwargs = kwargs.copy()
    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        # TODO allow draft-model specific load config.
        #load_config=load_config,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha)

    return spec_decode_worker


73
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

99
    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker, choosing the proposer and the sampler.

        Args:
            scorer_worker: The target-model worker used to score proposals.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                NOTE: this dict is mutated; the ``ngram_prompt_lookup_max``
                and ``ngram_prompt_lookup_min`` entries are popped from it
                (and ``model_runner_cls`` may be added).
            disable_by_batch_size: Forwarded to the SpecDecodeWorker
                constructor.
            draft_token_acceptance_method: Either ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Passed to
                TypicalAcceptanceSampler as ``posterior_threshold``.
            typical_acceptance_sampler_posterior_alpha: Passed to
                TypicalAcceptanceSampler as ``posterior_alpha``.

        Returns:
            The configured SpecDecodeWorker.

        Raises:
            ValueError: If ``draft_token_acceptance_method`` is not one of
                the supported values.
        """
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        # Bonus tokens stay disabled for draft-model proposers; the ngram and
        # MLP-speculator branches below re-enable them.
        disable_bonus_tokens = True
        if ngram_prompt_lookup_max > 0:
            disable_bonus_tokens = False
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        elif draft_worker_kwargs[
                "model_config"].hf_config.model_type == "mlp_speculator":
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            disable_bonus_tokens = False
        else:
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                "parallel_config"]
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_tp == 1:
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler(
                disable_bonus_tokens=disable_bonus_tokens, )
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                disable_bonus_tokens=disable_bonus_tokens,
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail fast: a None sampler would otherwise only surface as an
            # AttributeError deep inside the constructor.
            raise ValueError(
                "Unsupported draft_token_acceptance_method "
                f"{draft_token_acceptance_method!r}; expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))

        return SpecDecodeWorker(proposer_worker,
                                scorer_worker,
                                disable_by_batch_size=disable_by_batch_size,
                                spec_decode_sampler=spec_decode_sampler)
158

159
160
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. When None, an AsyncMetricsCollector
                wrapping 'spec_decode_sampler' is created.
            disable_by_batch_size: Once a request's running queue size
                reaches this value, speculative decoding is disabled for
                that step's requests. None (or 0) means never disable.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        # A falsy threshold (None or 0) is mapped to infinity so that the
        # ">=" comparison in _should_disable_all_speculation never fires.
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization: the scorer is created in init_device.
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None

203
    def init_device(self) -> None:
        """Initialize both scorer and proposer models.

        Besides device initialization this also loads both models, sets up
        the metrics collector and sampler tensors for this rank, creates the
        scorer, and configures the model samplers for speculative decoding.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

225
226
227
    def load_model(self, *args, **kwargs) -> None:
        """No-op: init_device already calls load_model on both the scorer and
        the proposer workers, so there is nothing left to load here.
        All arguments are accepted and ignored.
        """
        pass

228
229
230
    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of sampling during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        # Reach through the scorer worker to flip the flag on its sampler.
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True
        # The proposer exposes a dedicated hook for the same setting.
        self.proposer_worker.set_include_gpu_probs_tensor()
250

251
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple. The GPU count is the
            scorer's count reduced to leave room for the proposer's cache; the
            CPU count is returned exactly as reported by the scorer worker.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

272
273
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.

        Both workers receive the same block counts.

        Args:
            num_gpu_blocks: Number of GPU cache blocks for each worker.
            num_cpu_blocks: Number of CPU cache blocks for each worker.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)
280
281
282

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Args:
            execute_model_req: The request to execute. On the driver rank,
                None signals that there is no more work for now and causes an
                empty dict to be broadcast to the other workers.

        Returns:
            The sampler outputs for this step on the driver rank. Non-driver
            ranks, and the driver when given no request, return an empty
            list.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        if (num_lookahead_slots == 0
                or len(execute_model_req.seq_group_metadata_list) == 0
                or disable_all_speculation):
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Keep serving steps on a parallel (non-driver) worker.

        Repeatedly calls _run_non_driver_rank and stops as soon as that call
        returns a falsy value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

347
348
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return whether speculation should be disabled for this step.

        True when the request's running_queue_size is greater than or equal
        to self.disable_by_batch_size.
        """
        # When the batch size is too large, disable speculative decoding
        # to stop trading off throughput for latency.
        disable_all_speculation = (execute_model_req.running_queue_size >=
                                   self.disable_by_batch_size)

        return disable_all_speculation

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero out num_speculative_tokens on every sequence group when all
        speculation is disabled; otherwise leave the metadata untouched.
        """
        if disable_all_speculation:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for metadata in seq_group_metadata_list:
                metadata.num_speculative_tokens = 0
369
370

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is consistent
        between the two. When skip_proposer is True, the proposer model is
        not called, meaning that the kv-cache in proposer for requests is not
        updated, so they cannot enable spec decode in the rest decoding.

        Args:
            execute_model_req: The request to run through the worker(s).
            skip_proposer: If True, only the scorer worker is executed.

        Returns:
            A single-element list holding the scorer's SamplerOutput, with its
            probs, sampled_tokens and logprobs fields cleared.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    execute_model_req.seq_group_metadata_list, hidden_states)
            else:
                self.previous_hidden_states.update(
                    execute_model_req.seq_group_metadata_list, hidden_states)

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

403
    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        Returns True iff there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            # The driver broadcasts an empty dict to stop the execution loop.
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # Mirror the driver: when all speculation is disabled it skips the
        # proposer (see execute_model -> _run_no_spec(skip_proposer=True)),
        # so this rank must skip it too or the workers' collective calls go
        # out of step.
        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the proposer
            # model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True
427

428
429
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Args:
            execute_model_req: The request for this step. Its
                previous_hidden_states attribute is overwritten here.
            num_lookahead_slots: Must equal
                execute_model_req.num_lookahead_slots.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots)
463
464
465
466
467
468
469
470

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        As a side effect, when the scores carry hidden states, the hidden
        state at each sequence's last accepted position is stored in
        self.previous_hidden_states for the next decode step.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups in the
                batch, in their original order.
            proposal_scores: Scorer output (probs, token_ids, logprobs and
                optional hidden_states).
            proposals: Proposer output (proposal_lens, proposal_probs and
                proposal_token_ids).
            max_proposal_len: The batch proposal length; each row of the
                result has max_proposal_len + 1 entries.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.spec_decode_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor. Only the first column is kept for
        # those rows; the remaining positions are padded with -1.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[1]
            hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
                                                  hs_size)
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            # Number of non-padding tokens minus one == position of the last
            # accepted token in each row.
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                       hidden_states)

        return accepted_token_ids, logprobs
543
544
545
546
547

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups in the
                batch; supplies the sequence ids and per-sequence
                num_logprobs.
            accepted_token_ids: Accepted token ids, -1 where no token was
                accepted.
            target_logprobs: Logprobs according to the scoring model.
            k: Passed to the metrics collector when collecting rejection
                sampling metrics.

        Returns:
            One SamplerOutput per step, stopping at the first step in which
            every sequence has token id -1. Collected metrics, if any, are
            attached to the first element.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list: List[SamplerOutput] = []
        for step_index in range(num_steps):
            # Per-step rows, looked up once instead of once per sequence.
            step_token_ids = accepted_token_ids_by_step[step_index]
            if all(token_id == -1 for token_id in step_token_ids):
                # No sequence accepted a token at this step; stop here.
                break

            step_ranks = accepted_token_id_ranks_by_step[step_index]
            step_logprobs = accepted_token_id_logprobs_by_step[step_index]
            step_topk_indices = topk_indices_by_step[step_index]
            step_topk_logprobs = topk_logprobs_by_step[step_index]

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]

                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=step_token_ids[sequence_index],
                        token_id_logprob_rank=step_ranks[sequence_index],
                        token_id_logprob=step_logprobs[sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=step_topk_indices[sequence_index]
                        [:num_logprobs],
                        topk_logprobs=step_topk_logprobs[sequence_index]
                        [:num_logprobs],
                    ))

            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Vocabulary size shared by the draft and target workers.

        Reads vocab_size from the proposer worker and then from the scorer
        worker, asserts that the two agree, and returns that common value.
        """
        proposer_vocab_size = self.proposer_worker.vocab_size
        scorer_vocab_size = self.scorer_worker.vocab_size
        assert proposer_vocab_size == scorer_vocab_size
        return proposer_vocab_size

    @property
    def rank(self):
        """Rank of this worker, as reported by the scorer worker."""
        return self.scorer_worker.rank

    @property
    def device(self):
        """Device of this worker, as reported by the scorer worker."""
        return self.scorer_worker.device

649
650
651
652
    @property
    def _driver_rank(self) -> int:
        """Rank treated as the driver; it is the source of the broadcasts in
        this class. Always 0.
        """
        return 0

653
654
655
656
657
658
659
660
661
662
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Compute the per-model block count when draft and target share memory.

    total_num_gpu_blocks is the number of GPU blocks that could be allocated
    to the target model alone. Both models allocate the same number of
    blocks, but a block usually has a different size in bytes for each model
    (it depends on the number of KV/layer, the number of heads, and the
    hidden dimension size).

    The result is the number of blocks such that, if allocated by both
    models, the total KV cache memory is no larger than what the target
    model alone would have used for total_num_gpu_blocks blocks. The
    fractional part is truncated.
    """
    # Memory budget: what the target model's cache would occupy on its own.
    scorer_only_budget_bytes = (total_num_gpu_blocks *
                                scorer_cache_block_size_bytes)
    # Cost of giving one block to each of the two models.
    bytes_per_block_pair = (proposer_cache_block_size_bytes +
                            scorer_cache_block_size_bytes)
    return int(scorer_only_budget_bytes / bytes_per_block_pair)