spec_decode_worker.py 34.2 KB
Newer Older
1
from collections import defaultdict
2
from functools import cached_property
3
from typing import Any, Dict, List, Optional, Set, Tuple
4
5
6

import torch

7
from vllm.config import ParallelConfig, SpeculativeConfig
8
from vllm.distributed.communication_op import broadcast_tensor_dict
9
from vllm.logger import init_logger
10
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
11
from vllm.model_executor.layers.spec_decode_base_sampler import (
12
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
13
14
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
15
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
16
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
17
                           get_all_seq_ids_and_request_ids)
18
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
19
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
20
21
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
22
from vllm.spec_decode.medusa_worker import MedusaWorker
23
from vllm.spec_decode.metrics import AsyncMetricsCollector
24
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
25
from vllm.spec_decode.multi_step_worker import MultiStepWorker
26
from vllm.spec_decode.ngram_worker import NGramWorker
27
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
28
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
29
from vllm.spec_decode.util import (create_sequence_group_output,
30
                                   get_all_num_logprobs,
31
                                   get_sampled_token_logprobs, nvtx_range,
32
                                   split_batch_by_proposal_len)
33
from vllm.worker.worker import Worker
34
35
36
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)
37
38


39
40
41
42
43
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Entrypoint used by Executors that go through WorkerWrapper.

    Builds the target (scorer) ``Worker`` from the given arguments, derives
    the draft-worker keyword arguments from the speculative config, and
    returns the ``SpecDecodeWorker`` assembled from both.

    Args:
        *args: Positional arguments forwarded unchanged to ``Worker``.
        **kwargs: Keyword arguments forwarded to ``Worker``. Must contain a
            non-None ``speculative_config`` entry.

    Returns:
        The worker produced by ``SpecDecodeWorker.create_worker``.
    """
    assert "speculative_config" in kwargs
    spec_cfg: SpeculativeConfig = kwargs.get("speculative_config")
    assert spec_cfg is not None

    scorer = Worker(*args, **kwargs)

    # The draft worker starts from the same kwargs as the target worker; only
    # the draft-model specific entries are overridden.
    # TODO allow draft-model specific load config.
    proposer_kwargs = dict(
        kwargs,
        model_config=spec_cfg.draft_model_config,
        parallel_config=spec_cfg.draft_parallel_config,
        ngram_prompt_lookup_max=spec_cfg.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=spec_cfg.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=scorer,
        draft_worker_kwargs=proposer_kwargs,
        disable_by_batch_size=spec_cfg.speculative_disable_by_batch_size,
        draft_token_acceptance_method=(
            spec_cfg.draft_token_acceptance_method),
        typical_acceptance_sampler_posterior_threshold=(
            spec_cfg.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            spec_cfg.typical_acceptance_sampler_posterior_alpha),
    )


75
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
76
77
78
79
80
81
82
83
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

84
    See https://github.com/vllm-project/vllm/pull/2188 and
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

101
    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker with the matching proposer and sampler.

        Args:
            scorer_worker: The target-model worker used for scoring.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``; for non-ngram proposers it must
                also contain ``parallel_config`` and ``model_config``. The
                dict passed in is not modified.
            disable_by_batch_size: Forwarded to the SpecDecodeWorker
                constructor.
            draft_token_acceptance_method: Either ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Passed to
                TypicalAcceptanceSampler as ``posterior_threshold``.
            typical_acceptance_sampler_posterior_alpha: Passed to
                TypicalAcceptanceSampler as ``posterior_alpha``.

        Returns:
            The configured SpecDecodeWorker.

        Raises:
            ValueError: If ``draft_token_acceptance_method`` is not one of
                the supported values.
        """
        # Work on a private copy: entries are popped and inserted below, and
        # the caller's dict must stay reusable.
        draft_worker_kwargs = dict(draft_worker_kwargs)

        allow_zero_draft_token_step = True
        # The ngram window settings are not proposer constructor arguments,
        # so they are removed before the kwargs are forwarded.
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            model_type = (
                draft_worker_kwargs["model_config"].hf_config.model_type)
            if model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    draft_worker_kwargs[
                        "model_runner_cls"] = TP1DraftModelRunner
                else:
                    # Steps without any draft token are not supported when
                    # the draft model runs with tp > 1 (TODO: #5814).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler(
                disable_bonus_tokens=False, )
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                disable_bonus_tokens=False,
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail here with a clear message instead of passing None on to
            # the constructor, which dereferences the sampler.
            raise ValueError(
                "Unsupported draft_token_acceptance_method: "
                f"{draft_token_acceptance_method!r}. Expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
167

168
169
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Either a RejectionSampler or a
                TypicalAcceptanceSampler instance.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. When None, an AsyncMetricsCollector
                wrapping the sampler is created.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests. A
                falsy value (None or 0) means "never disable".
            allow_zero_draft_token_step: whether to allow a step where the
                draft model generates no draft token; should disallow when
                the tp of draft model is larger than 1 (TODO: #5814)
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker

        # A falsy threshold is mapped to infinity so that the comparison in
        # _should_disable_all_speculation can never succeed.
        if disable_by_batch_size:
            self.disable_by_batch_size = disable_by_batch_size
        else:
            self.disable_by_batch_size = float("inf")

        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step

        if metrics_collector is None:
            self._metrics = AsyncMetricsCollector(self.spec_decode_sampler)
        else:
            self._metrics = metrics_collector

        # Sequence IDs that received a bonus token ID in their last forward
        # pass. Needed only if KV cache is being used for token generation,
        # such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Currently active request ids and the sequence IDs belonging to
        # each of them.
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        # Dtypes are taken over from the sampler.
        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype

        # Declared here, assigned lazily in init_device.
        self.scorer: SpeculativeScorer

        # Hidden states from the target model, handed to the proposer in the
        # subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None

226
    def init_device(self) -> None:
        """Initialize devices, load both models and set up the scorer.

        The scorer worker always goes first, in case the proposer model has
        a smaller TP degree than the target worker.
        """
        workers_in_order = (self.scorer_worker, self.proposer_worker)

        for worker in workers_in_order:
            worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        for worker in workers_in_order:
            worker.load_model()

        for gpu_component in (self._metrics, self.spec_decode_sampler):
            gpu_component.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size,
        )

        self._configure_model_sampler_for_spec_decode()

248
249
250
    def load_model(self, *args, **kwargs):
        """Do nothing; all arguments are ignored.

        The scorer and proposer models are loaded by init_device.
        """
        return None

251
252
253
    def _configure_model_sampler_for_spec_decode(self):
        """Make the scorer and proposer samplers emit GPU probability tensors.

        Keeping the data on device lets spec decode avoid transferring to CPU
        and serializing, which significantly reduces the overhead of sampling
        during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        # Reach through the scorer worker down to its model's sampler.
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True

        # The proposer exposes a dedicated hook for the same switch.
        self.proposer_worker.set_include_gpu_probs_tensor()
273

274
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        The scorer model (typically the larger of the two) is profiled first.
        The memory its cache would occupy is then divided between the scorer
        and proposer KV caches such that both end up with the same number of
        blocks.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple. The GPU count is
            the evenly split value; the CPU count is the scorer's, unchanged.
        """
        scorer_gpu_blocks, cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_block_bytes = self.scorer_worker.get_cache_block_size_bytes()
        proposer_block_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, scorer_gpu_blocks)
        return shared_gpu_blocks, cpu_blocks

295
296
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engines of the scorer and proposer workers.

        Both workers receive the same block counts; the scorer is
        initialized first.
        """
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)
303
304
305

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Args:
            execute_model_req: The request to execute. ``None`` signals that
                there is nothing more to process for now; the driver then
                tells the other workers to leave their execution loop.

        Returns:
            The sampler outputs for this step. Non-driver ranks, and the
            driver when ``execute_model_req`` is None, return an empty list.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. The source is the driver rank, like every other
            # broadcast issued by this worker.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        if num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list
        ) == 0 or disable_all_speculation:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)

        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Run the non-driver worker loop for speculative decoding.

        Keeps calling _run_non_driver_rank until it reports that nothing is
        left to process.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

371
372
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Tell whether speculation must be turned off for this request.

        When the batch size is too large, speculative decoding is disabled
        to stop trading off throughput for latency. The running queue size
        is compared against ``self.disable_by_batch_size``.
        """
        return (execute_model_req.running_queue_size >=
                self.disable_by_batch_size)

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` on every sequence group.

        Does nothing unless ``disable_all_speculation`` is set.
        """
        if disable_all_speculation:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for group_metadata in seq_group_metadata_list:
                group_metadata.num_speculative_tokens = 0
393
394

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation.

        The request is normally executed by both the proposer and the scorer
        so that their KV caches stay consistent. With ``skip_proposer`` set,
        the proposer is not called; its KV cache is then not updated for
        these requests, so they cannot use spec decode for the remainder of
        their decoding.

        Returns:
            A single-element list holding the scorer's sampler output, with
            its device tensors cleared.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        output = scorer_outputs[0]

        # Keep the target model's hidden states: start a new HiddenStates
        # record if none is pending, otherwise extend the pending one.
        target_hidden_states = output.hidden_states
        if target_hidden_states is not None:
            seq_groups = execute_model_req.seq_group_metadata_list
            if self.previous_hidden_states is not None:
                self.previous_hidden_states.update(seq_groups,
                                                   target_hidden_states)
            else:
                self.previous_hidden_states = HiddenStates(
                    seq_groups, target_hidden_states)

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        output.probs = None
        output.sampled_tokens = None
        output.logprobs = None
        return [output]

427
    def _run_non_driver_rank(self) -> bool:
        """Run the proposer and the scorer on a non-driver worker.

        Used for both speculation cases (num_lookahead_slots>0) and
        non-speculation cases (e.g. prefill).

        Returns:
            False when the driver broadcast an empty payload, True after a
            step has been executed.
        """
        assert self.rank != self._driver_rank

        payload = broadcast_tensor_dict(src=self._driver_rank)
        if not payload:
            return False

        # Even if num_lookahead_slots is zero, we want to run the proposer
        # model as it may have KV, hence at least one run.
        #
        # We run the proposer once per lookahead slot. In the future we should
        # delegate how many times it runs to the proposer.
        proposer_runs = max(payload["num_lookahead_slots"], 1)
        for _ in range(proposer_runs):
            self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True
451

452
453
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        The proposer worker produces speculative tokens for each sequence,
        the scorer scores them, and the acceptance sampler decides which of
        them are kept.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.

        Raises:
            RuntimeError: If the proposer produced no proposals while
                zero-draft-token steps are not allowed.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Hand the last target-model hidden states to the proposer and clear
        # the pending value.
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        # Draft phase.
        proposals = self.proposer_worker.get_spec_proposals(
            execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step:
            if proposals.no_proposals:
                #TODO: Fix it #5814
                raise RuntimeError(
                    "Cannot handle cases where distributed draft "
                    "workers generate no tokens")

        # Scoring phase.
        proposal_scores = self.scorer.score_proposals(execute_model_req,
                                                      proposals)

        # Verification phase.
        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list,
            proposal_scores,
            proposals,
            execute_model_req.num_lookahead_slots,
        )

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
        )
493
494
495
496
497
498
499
500

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decide which speculative tokens are accepted.

        The acceptance sampler compares the proposer's and the scorer's
        probabilities for each proposed token. If the scorer returned hidden
        states, the ones matching the last accepted position of each sequence
        are stored in ``self.previous_hidden_states`` as a side effect.

        Returns a tuple of Tensors, one for the accepted token ids (rows in
        the order of ``seq_group_metadata_list``) and one for the logprobs
        according to the scoring model.
        """
        lens_per_seq = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            lens_per_seq,
            select_proposal_len_zero=False,
        )
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            lens_per_seq,
            select_proposal_len_zero=True,
        )
        # Row i of the concatenated result below belongs to batch position
        # original_indices[i].
        original_indices = spec_indices + non_spec_indices

        scored_token_ids = proposal_scores.token_ids

        # Speculative rows: target probabilities without the bonus position,
        # and the bonus token itself (last column).
        target_probs = proposal_scores.probs[spec_indices, :-1]
        bonus_token_ids = scored_token_ids[spec_indices, -1:]
        # Non-speculative rows: tokens sampled by the target model.
        non_spec_token_ids = scored_token_ids[non_spec_indices]

        # Proposer-side probabilities and proposed tokens for the
        # speculative rows.
        draft_probs = proposals.proposal_probs[spec_indices]
        draft_token_ids = proposals.proposal_token_ids[spec_indices]

        # Stochastic samplers additionally receive the per-sequence-group
        # generator (None where a group has no state or no generator).
        # NOTE(review): the list covers every sequence group, while the
        # tensors above only cover spec_indices -- confirm the sampler's
        # expected indexing.
        extra_sampler_kwargs = {}
        if isinstance(self.spec_decode_sampler,
                      SpecDecodeStochasticBaseSampler):
            extra_sampler_kwargs["generators"] = [
                group.state.generator if group.state is not None else None
                for group in seq_group_metadata_list
            ]

        accepted_token_ids = self.spec_decode_sampler(
            target_probs=target_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=draft_probs,
            draft_token_ids=draft_token_ids,
            **extra_sampler_kwargs,
        )

        # Append output tokens from non-speculative sequences to the accepted
        # token ids tensor: only the first column is kept, the remaining
        # positions are filled with -1.
        padded_non_spec = non_spec_token_ids.expand(
            -1, max_proposal_len + 1).clone()
        padded_non_spec[:, 1:] = -1
        accepted_token_ids = torch.cat([accepted_token_ids, padded_non_spec])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata. The clone keeps the source rows intact while scattering.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens.
            hs_size = hidden_states.shape[1]
            per_seq_hidden = hidden_states.reshape(-1, max_proposal_len + 1,
                                                   hs_size)
            # Shifting by one maps the -1 entries to 0, so the non-zero count
            # minus one is the index of the last non -1 position per row.
            last_accepted = (accepted_token_ids +
                             1).count_nonzero(dim=1).add_(-1)
            gather_index = last_accepted[:, None, None].expand(-1, 1, hs_size)
            selected_hidden = per_seq_hidden.gather(
                1, gather_index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                       selected_hidden)

        return accepted_token_ids, logprobs
590
591
592
593
594

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
595
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
596
597
598
599
600
601
602
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
626
627
628
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

629
630
631
632
633
634
635
636
637
638
639
640
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
641
        sampler_output_list: List[SamplerOutput] = []
642
643
644
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
645
646
                break

647
            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
648
649
650
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
651
                step_output_token_ids.append(
652
653
654
655
656
657
658
659
660
661
662
663
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
664
665
666
667
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

668
669
670
671
672
        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
673
674
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
675
676
677
678
679
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics
        return sampler_output_list

680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Drops finished requests from the internal book keeping.

        For each request id listed in
        ``execute_model_req.finished_requests_ids``, the sequence ids recorded
        for that request are removed from the set of sequences that got a
        bonus token in their last step, and then the request's entry in the
        request id -> sequence ids mapping is deleted.
        """
        for request_id in execute_model_req.finished_requests_ids:
            finished_seq_ids = self._request_id_seq_id_mapping[request_id]
            # Removing ids that are not in the set is a no-op, exactly like
            # calling discard() on each id individually.
            self._seq_with_bonus_token_in_last_step.difference_update(
                finished_seq_ids)
            del self._request_id_seq_id_mapping[request_id]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Refreshes the book keeping of sequences that were assigned a bonus
        token in their last forward pass.

        A sequence is recorded as having a bonus token when its entry in the
        final step of ``accepted_token_ids_by_step`` is anything other than
        -1; otherwise it is removed from the record. Afterwards the sequence
        ids in ``request_ids_seq_ids_mapping`` are merged into the per-request
        sequence id mapping kept on this worker.
        """
        for position, seq_id in enumerate(seq_ids):
            # Only the final step matters: -1 there means no token was
            # accepted at the last position for this sequence.
            got_bonus_token = accepted_token_ids_by_step[-1][position] != -1
            if got_bonus_token:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
        for request_id, new_seq_ids in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(new_seq_ids)

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
    @cached_property
    def _vocab_size(self) -> int:
        """Return the vocab size shared by the draft and target workers.

        The proposer worker's and the scorer worker's ``vocab_size`` are read
        and asserted to be equal; the common value is returned.
        """
        proposer_vocab_size = self.proposer_worker.vocab_size
        scorer_vocab_size = self.scorer_worker.vocab_size
        assert proposer_vocab_size == scorer_vocab_size
        return proposer_vocab_size

    @property
    def rank(self):
        """Rank of this worker, as reported by the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of this worker, as reported by the scorer worker."""
        scorer = self.scorer_worker
        return scorer.device

727
728
729
730
    @property
    def _driver_rank(self) -> int:
        """Rank used as the driver rank; this implementation always gives 0."""
        return 0

731
732
733
734
735
736
737
738
739
740
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes (unsupported here).

        This function is only used to compose workers within a
        SpecDecodeWorker. Composing a SpecDecodeWorker within another
        SpecDecodeWorker is left undefined for now, although it could be
        implemented in the future. See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError()

741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model alone, calculate how many blocks each of
    the draft and target models should be given.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Size in bytes of one cache block of
            the target (scorer) model.
        proposer_cache_block_size_bytes: Size in bytes of one cache block of
            the draft (proposer) model.
        total_num_gpu_blocks: Number of blocks the target model could
            allocate if it had all of the memory.

    Returns:
        The number of blocks each model should allocate, rounded down.

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)

    # Floor division keeps the whole computation in exact integer arithmetic.
    # Going through float true division loses precision once the product
    # exceeds 2**53 and may round the result up, over-allocating blocks.
    new_num_gpu_blocks = int(total_num_gpu_blocks *
                             scorer_cache_block_size_bytes //
                             combined_block_size_bytes)

    return new_num_gpu_blocks