spec_decode_worker.py 28.8 KB
Newer Older
1
from functools import cached_property
2
from typing import Any, Dict, List, Optional, Tuple
3
4
5

import torch

6
from vllm.config import ParallelConfig, SpeculativeConfig
7
from vllm.distributed.communication_op import broadcast_tensor_dict
8
from vllm.logger import init_logger
9
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
10
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
11
12
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
                           get_all_seq_ids)
13
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
14
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
15
16
17
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
18
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
19
from vllm.spec_decode.multi_step_worker import MultiStepWorker
20
from vllm.spec_decode.ngram_worker import NGramWorker
21
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
22
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
23
from vllm.spec_decode.util import (create_sequence_group_output,
24
                                   get_all_num_logprobs,
25
                                   get_sampled_token_logprobs, nvtx_range,
26
                                   split_batch_by_proposal_len)
27
from vllm.worker.worker import Worker
28
29
30
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)
31
32


33
34
35
36
37
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from the ``speculative_config`` keyword.

    This is the entrypoint for Executors which use WorkerWrapper. The target
    (scorer) worker is constructed from the arguments as given; the draft
    worker receives a copy of the keyword arguments in which the model and
    parallel configs are replaced by the draft-specific ones and the n-gram
    lookup bounds are added.
    """
    assert "speculative_config" in kwargs
    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
    assert speculative_config is not None

    scorer = Worker(*args, **kwargs)

    # Draft-model specific worker args that replace the target's values.
    # TODO allow draft-model specific load config.
    draft_overrides = {
        "model_config": speculative_config.draft_model_config,
        "parallel_config": speculative_config.draft_parallel_config,
        "ngram_prompt_lookup_max": speculative_config.ngram_prompt_lookup_max,
        "ngram_prompt_lookup_min": speculative_config.ngram_prompt_lookup_min,
    }
    draft_worker_kwargs = {**kwargs, **draft_overrides}

    return SpecDecodeWorker.create_worker(
        scorer_worker=scorer,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=(
            speculative_config.speculative_disable_by_batch_size),
    )


64
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
65
66
67
68
69
70
71
72
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

73
    See https://github.com/vllm-project/vllm/pull/2188 and
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model, n-gram prompt-lookup and MLP-speculator proposers are
        implemented (contributions for more forms are welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

92
    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
    ) -> "SpecDecodeWorker":
        """Pick a proposer implementation and assemble a SpecDecodeWorker.

        Args:
            scorer_worker: The target-model worker used for scoring.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                The ``ngram_prompt_lookup_max`` / ``ngram_prompt_lookup_min``
                entries are popped from this dict (it is mutated in place).
            disable_by_batch_size: Forwarded to the SpecDecodeWorker
                constructor.

        The proposer is an NGramWorker when ``ngram_prompt_lookup_max`` is
        positive, an MLPSpeculatorWorker when the draft model type is
        "mlp_speculator", and otherwise a MultiStepWorker (possibly wrapped by
        SmallerTpProposerWorker). Bonus tokens are disabled only in the last
        case.
        """
        lookup_max = draft_worker_kwargs.pop("ngram_prompt_lookup_max")
        lookup_min = draft_worker_kwargs.pop("ngram_prompt_lookup_min")

        if lookup_max > 0:
            # N-gram prompt lookup.
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(lookup_min, lookup_max)
            disable_bonus_tokens = False
        elif (draft_worker_kwargs["model_config"].hf_config.model_type ==
              "mlp_speculator"):
            proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            disable_bonus_tokens = False
        else:
            # Regular draft model.
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_tp == 1:
                draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
            multi_step_worker = MultiStepWorker(**draft_worker_kwargs)
            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                multi_step_worker, draft_tp, target_tp)
            disable_bonus_tokens = True

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        rejection_sampler = RejectionSampler(
            disable_bonus_tokens=disable_bonus_tokens)
        return SpecDecodeWorker(proposer_worker,
                                scorer_worker,
                                disable_by_batch_size=disable_by_batch_size,
                                rejection_sampler=rejection_sampler)
135

136
137
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. When None, an AsyncMetricsCollector
                wrapping ``rejection_sampler`` is created.
            disable_by_batch_size: Once the running queue size reaches this
                value, speculative decoding is disabled for the requests in
                the batch. ``None`` (or any falsy value such as 0) means it is
                never disabled this way.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker

        # A falsy threshold is stored as infinity so it can never be reached.
        if disable_by_batch_size:
            self.disable_by_batch_size = disable_by_batch_size
        else:
            self.disable_by_batch_size = float("inf")

        self.rejection_sampler = rejection_sampler

        if metrics_collector is None:
            self._metrics = AsyncMetricsCollector(rejection_sampler)
        else:
            self._metrics = metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Lazy initialization: assigned in init_device().
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None

179
    def init_device(self) -> None:
        """Initialize devices, load both models and set up scoring state.

        After both workers are ready, the metrics collector and rejection
        sampler get their GPU tensors, the batch-expansion scorer is created,
        and the samplers are configured to keep probabilities on device.
        """
        # Scorer first, proposer second: the scorer worker is initialized
        # first in case the proposer model has a smaller TP degree than the
        # target worker.
        workers = (self.scorer_worker, self.proposer_worker)

        for worker in workers:
            worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        for worker in workers:
            worker.load_model()

        for component in (self._metrics, self.rejection_sampler):
            component.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

200
201
202
    def load_model(self, *args, **kwargs):
        """Do nothing; init_device() already loads the scorer and proposer
        models, so all arguments are ignored.
        """
        return None

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
    def _configure_model_sampler_for_spec_decode(self):
        """Make the scorer and proposer samplers emit GPU probability tensors.

        Keeping the data on device avoids transferring to CPU and serializing,
        which significantly reduces the overhead of rejection sampling.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        The proposer side is delegated to the proposer worker's own
        ``set_include_gpu_probs_tensor``.
        """
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True

        self.proposer_worker.set_include_gpu_probs_tensor()
225

226
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return the (GPU, CPU) cache block counts to use for both workers.

        The scorer worker (typically the larger model) is profiled to get its
        block counts. The GPU count is then reduced by
        split_num_cache_blocks_evenly so that the proposer and scorer KV caches
        can each hold that same number of blocks. The CPU count is returned as
        reported by the scorer worker.
        """
        scorer = self.scorer_worker
        num_gpu_blocks, num_cpu_blocks = scorer.determine_num_available_blocks()

        scorer_block_bytes = scorer.get_cache_block_size_bytes()
        proposer_block_bytes = self.proposer_worker.get_cache_block_size_bytes()

        shared_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, num_gpu_blocks)
        return shared_num_gpu_blocks, num_cpu_blocks

247
248
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache of the scorer worker, then of the proposer
        worker, both with the same block counts.
        """
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)
255
256
257

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Args:
            execute_model_req: The request to run. ``None`` on the driver rank
                signals that there is no more work for now.

        Returns:
            The sampler outputs of the step on the driver rank. An empty list
            on non-driver ranks and when ``execute_model_req`` is None.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. The source rank is the driver rank, as for every
            # other broadcast issued by this class.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        #
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        assert seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(disable_all_speculation,
                                               seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        no_speculation = (num_lookahead_slots == 0
                          or not seq_group_metadata_list
                          or disable_all_speculation)
        if no_speculation:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Run non-driver steps repeatedly in a parallel worker until
        _run_non_driver_rank reports that there is nothing left to process.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

322
323
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Tell whether speculation must be switched off for this request.

        When the batch size is too large (the running queue size has reached
        ``disable_by_batch_size``), speculative decoding is disabled to stop
        trading off throughput for latency.
        """
        queue_size = execute_model_req.running_queue_size
        return queue_size >= self.disable_by_batch_size

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` on every sequence group when all
        speculation is disabled; otherwise leave the metadata untouched.
        """
        if disable_all_speculation:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for metadata in seq_group_metadata_list:
                metadata.num_speculative_tokens = 0
344
345

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run one generation step without speculation.

        The request is normally executed by both the proposer and the scorer
        so that their KV caches stay consistent. With ``skip_proposer`` set,
        the proposer is not called; its KV cache is then not updated for these
        requests, so they cannot use spec decode for the rest of their
        decoding.

        Returns a one-element list holding the scorer's sampler output with
        its probs, sampled_tokens and logprobs fields cleared.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        step_output = scorer_outputs[0]

        # Remember the target model's hidden states for the next step.
        target_hidden_states = step_output.hidden_states
        if target_hidden_states is not None:
            seq_groups = execute_model_req.seq_group_metadata_list
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    seq_groups, target_hidden_states)
            else:
                self.previous_hidden_states.update(seq_groups,
                                                   target_hidden_states)

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        for field_name in ("probs", "sampled_tokens", "logprobs"):
            setattr(step_output, field_name, None)

        return [step_output]

378
    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        The step is driven by the dict broadcast from the driver rank in
        execute_model: an empty dict means "stop", otherwise it carries
        ``num_lookahead_slots`` and ``disable_all_speculation``.

        Returns True iff there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # When all speculation is disabled the driver skips the proposer
        # entirely (see _run_no_spec with skip_proposer=True), so this rank
        # must skip it too; otherwise the ranks would issue different
        # sequences of model executions.
        if not data.get("disable_all_speculation", False):
            # Even if num_lookahead_slots is zero, we want to run the proposer
            # model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True
402

403
404
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        The proposer worker produces speculative tokens for each sequence, the
        scorer scores them, and the verified tokens are converted to sampler
        outputs.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Hand the target model's last hidden states to the proposer and
        # forget them locally.
        last_hidden_states = self.previous_hidden_states
        execute_model_req.previous_hidden_states = last_hidden_states
        self.previous_hidden_states = None

        # Draft, then score.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        seq_groups = execute_model_req.seq_group_metadata_list
        lookahead = execute_model_req.num_lookahead_slots

        accepted_token_ids, target_logprobs = self._verify_tokens(
            seq_groups, proposal_scores, proposals, lookahead)

        return self._create_output_sampler_list(
            seq_groups,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=lookahead)
438
439
440
441
442
443
444
445

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run rejection sampling to decide which speculative tokens are
        accepted, using the token probabilities of the proposer and the scorer.

        As a side effect, when the scores carry hidden states, the hidden
        state at each sequence's last emitted position is stored in
        ``self.previous_hidden_states`` for the next step.

        Returns a tuple of Tensors, one for the accepted token ids (padded
        with -1) and one for the logprobs according to the scoring model.
        """
        proposal_lens = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_rows = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens,
            select_proposal_len_zero=False)
        _, non_spec_rows = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens,
            select_proposal_len_zero=True)
        original_order = spec_rows + non_spec_rows

        target_token_ids = proposal_scores.token_ids

        # Tokens sampled by the target model for non-speculative sequences.
        non_spec_token_ids = target_token_ids[non_spec_rows]

        accepted_token_ids = self.rejection_sampler(
            # Target probabilities, excluding the bonus token position.
            target_probs=proposal_scores.probs[spec_rows, :-1],
            # Bonus tokens are the target's tokens at the last position.
            bonus_token_ids=target_token_ids[spec_rows, -1:],
            # Probabilities and token ids according to the proposal method.
            draft_probs=proposals.proposal_probs[spec_rows],
            draft_token_ids=proposals.proposal_token_ids[spec_rows],
        )

        # Append output tokens from non-speculative sequences to the accepted
        # token ids tensor; only their first position holds a real token.
        padded_non_spec = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        padded_non_spec[:, 1:] = -1
        accepted_token_ids = torch.cat([accepted_token_ids, padded_non_spec])

        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_order] = accepted_token_ids.clone()

        target_hidden_states = proposal_scores.hidden_states
        if target_hidden_states is not None:
            # Contract hidden states based on accepted tokens.
            hidden_dim = target_hidden_states.shape[1]
            per_seq_hidden = target_hidden_states.reshape(
                -1, max_proposal_len + 1, hidden_dim)
            # Adding 1 turns the -1 padding into 0, so the non-zero count is
            # the number of emitted tokens per sequence.
            num_emitted = (accepted_token_ids + 1).count_nonzero(dim=1)
            last_emitted = num_emitted - 1
            gather_index = last_emitted[:, None, None].expand(
                -1, 1, hidden_dim)
            selected_hidden = per_seq_hidden.gather(
                1, gather_index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                       selected_hidden)

        return accepted_token_ids, proposal_scores.logprobs
519
520
521
522
523

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Convert the accepted token ids into one SamplerOutput per step.

        Steps are emitted in order until the first step in which every
        sequence holds the -1 padding token. Within an emitted step every
        sequence gets an output, so sequences that accepted fewer tokens are
        padded with -1 tokens. Rejection-sampling metrics, when available, are
        attached to the first SamplerOutput.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # View everything step-major rather than sequence-major.
        logprobs_by_step = target_logprobs.transpose(0, 1)
        token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Rank and logprob of each accepted token.
        ranks_by_step, token_logprobs_by_step = get_sampled_token_logprobs(
            logprob_tensor=logprobs_by_step,
            sampled_token_ids=token_ids_by_step,
        )

        # Top-k logprobs (which may or may not include the logprob of the
        # accepted token).
        topk_values_by_step, topk_ids_by_step = logprobs_by_step.topk(
            k=self.scorer_worker.model_config.max_logprobs,
            dim=-1,
        )

        # Sequence ids and num_logprobs (sampling parameter) in the batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        token_ids_by_step = token_ids_by_step.tolist()
        ranks_by_step = ranks_by_step.tolist()
        token_logprobs_by_step = token_logprobs_by_step.tolist()
        topk_values_by_step = topk_values_by_step.tolist()
        topk_ids_by_step = topk_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list: List[SamplerOutput] = []
        for step, step_token_ids in enumerate(token_ids_by_step):
            if all(token_id == -1 for token_id in step_token_ids):
                break

            step_ranks = ranks_by_step[step]
            step_token_logprobs = token_logprobs_by_step[step]
            step_topk_ids = topk_ids_by_step[step]
            step_topk_values = topk_values_by_step[step]

            # Each sequence may have a different num_logprobs, so the top-k
            # lists are truncated per sequence.
            step_outputs: List[CompletionSequenceGroupOutput] = [
                create_sequence_group_output(
                    token_id=step_token_ids[seq],
                    token_id_logprob_rank=step_ranks[seq],
                    token_id_logprob=step_token_logprobs[seq],
                    seq_id=seq_ids[seq],
                    topk_token_ids=step_topk_ids[seq]
                    [:num_logprobs_per_seq[seq]],
                    topk_logprobs=step_topk_values[seq]
                    [:num_logprobs_per_seq[seq]],
                ) for seq in range(batch_size)
            ]

            sampler_output_list.append(SamplerOutput(outputs=step_outputs))

        rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(k)
        if rejsample_metrics is not None:
            first_output = sampler_output_list[0]
            first_output.spec_decode_worker_metrics = rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Return the vocab size shared by the draft and target workers,
        asserting that both report the same value.
        """
        proposer_vocab_size = self.proposer_worker.vocab_size
        scorer_vocab_size = self.scorer_worker.vocab_size
        assert proposer_vocab_size == scorer_vocab_size
        return proposer_vocab_size

    @property
    def rank(self):
        """Rank of this worker, as reported by the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of this worker, as reported by the scorer worker."""
        scorer = self.scorer_worker
        return scorer.device

625
626
627
628
    @property
    def _driver_rank(self) -> int:
        """Rank that drives execution and is the source of broadcasts."""
        driver_rank = 0
        return driver_rank

629
630
631
632
633
634
635
636
637
638
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes (not implemented).

        This hook is only needed when a worker is composed inside a
        SpecDecodeWorker. Nesting a SpecDecodeWorker within another
        SpecDecodeWorker is left undefined for now, although it could be
        implemented in the future. See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Size of one target-model cache block.
        proposer_cache_block_size_bytes: Size of one draft-model cache block.
        total_num_gpu_blocks: Blocks the target model could allocate alone.

    Returns:
        The number of blocks each model gets, rounded down.

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    # Memory consumed when each model allocates one block.
    bytes_per_block_pair = (proposer_cache_block_size_bytes +
                            scorer_cache_block_size_bytes)

    # Floor division keeps the computation in exact integer arithmetic; going
    # through float division loses precision once the product exceeds 2**53.
    total_scorer_bytes = total_num_gpu_blocks * scorer_cache_block_size_bytes
    return int(total_scorer_bytes // bytes_per_block_pair)