# spec_decode_worker.py (43 KB)
1
from collections import defaultdict
2
from functools import cached_property
3
from typing import Any, Dict, List, Optional, Set, Tuple
4
5
6

import torch

7
from vllm.config import ParallelConfig, SpeculativeConfig
8
from vllm.distributed.communication_op import broadcast_tensor_dict
9
from vllm.logger import init_logger
10
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
11
from vllm.model_executor.layers.spec_decode_base_sampler import (
12
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
13
14
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
15
from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
16
                           HiddenStates, SamplerOutput, SequenceGroupMetadata,
17
                           get_all_seq_ids, get_all_seq_ids_and_request_ids)
18
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
19
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
20
21
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
22
from vllm.spec_decode.medusa_worker import MedusaWorker
23
from vllm.spec_decode.metrics import AsyncMetricsCollector
24
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
25
from vllm.spec_decode.multi_step_worker import MultiStepWorker
26
from vllm.spec_decode.ngram_worker import NGramWorker
27
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
28
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
29
from vllm.spec_decode.target_model_runner import TargetModelRunner
30
from vllm.spec_decode.util import (Timer, create_sequence_group_output,
31
                                   get_all_num_logprobs,
32
                                   get_sampled_token_logprobs, nvtx_range,
33
                                   split_batch_by_proposal_len)
34
from vllm.worker.worker import Worker
35
36
37
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

# Module-level logger for speculative-decoding worker configuration messages.
logger = init_logger(__name__)


def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.

    The given ``*args`` / ``**kwargs`` build the target (scorer) ``Worker``
    with ``TargetModelRunner``. A copy of ``kwargs`` is taken first, then has
    the draft model/parallel configs and the n-gram window bounds substituted,
    and is forwarded to ``SpecDecodeWorker.create_worker`` to build the
    proposer.

    Args:
        *args: Positional arguments forwarded to ``Worker``.
        **kwargs: Keyword arguments forwarded to ``Worker``. Must contain a
            non-None ``speculative_config`` (a ``SpeculativeConfig``).

    Returns:
        The configured SpecDecodeWorker.
    """
    assert "speculative_config" in kwargs
    speculative_config: SpeculativeConfig = kwargs.get("speculative_config")
    assert speculative_config is not None

    # Snapshot the kwargs before the target-only model runner is injected so
    # the draft worker does not inherit TargetModelRunner.
    draft_worker_kwargs = kwargs.copy()

    kwargs["model_runner_cls"] = TargetModelRunner
    target_worker = Worker(*args, **kwargs)
    # Set the disable_logprobs variable in the TargetModelRunner instance
    # as per its value specified in the SpeculativeConfig.
    target_worker.model_runner.disable_logprobs = (
        speculative_config.disable_logprobs)

    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        # TODO allow draft-model specific load config.
        #load_config=load_config,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha,
        disable_logprobs=speculative_config.disable_logprobs,
        disable_log_stats=speculative_config.disable_log_stats,
    )

    return spec_decode_worker


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Proposals come only from a draft model, n-gram prompt lookup, an
        MLPSpeculator, or Medusa heads (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions adding MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Build the proposer worker and acceptance sampler, then wrap them
        (together with ``scorer_worker``) in a SpecDecodeWorker.

        Proposer selection:
          * ``ngram_prompt_lookup_max > 0``: ``NGramWorker``.
          * draft ``hf_config.model_type == "mlp_speculator"``:
            ``MLPSpeculatorWorker``.
          * draft ``hf_config.model_type == "medusa"``: ``MedusaWorker``.
          * otherwise: ``MultiStepWorker``, using ``TP1DraftModelRunner`` when
            the draft tensor-parallel size is 1.
        Every non-n-gram proposer is passed through
        ``SmallerTpProposerWorker.maybe_wrap_worker``.

        Args:
            scorer_worker: The target-model worker that scores proposals.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                ``ngram_prompt_lookup_max`` and ``ngram_prompt_lookup_min``
                are popped from it (the dict is mutated); ``model_config``
                and ``parallel_config`` are read when n-gram lookup is off.
            disable_by_batch_size: Forwarded to ``SpecDecodeWorker``.
            draft_token_acceptance_method: ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Used only by the
                typical acceptance sampler.
            typical_acceptance_sampler_posterior_alpha: Used only by the
                typical acceptance sampler.
            disable_logprobs: Forwarded to ``SpecDecodeWorker``.
            disable_log_stats: Forwarded to ``SpecDecodeWorker``.

        Raises:
            ValueError: If ``draft_token_acceptance_method`` is not supported.
        """
        allow_zero_draft_token_step = True
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size
            draft_model_type = (
                draft_worker_kwargs["model_config"].hf_config.model_type)

            if draft_model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    draft_worker_kwargs[
                        "model_runner_cls"] = TP1DraftModelRunner
                else:
                    # Distributed draft workers cannot yet handle a step in
                    # which no draft tokens are produced (TODO: #5814).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler(disable_bonus_tokens=False)
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                disable_bonus_tokens=False,
                posterior_threshold=(
                    typical_acceptance_sampler_posterior_threshold),
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail fast with a clear message instead of passing a None
            # sampler into __init__, which dereferences it immediately.
            raise ValueError(
                "Unsupported draft_token_acceptance_method: "
                f"{draft_token_acceptance_method!r}. Expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: bool = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
            disable_log_stats: If set to True, disable periodic printing of
                speculative stage times.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. When None, an AsyncMetricsCollector
                wrapping ``spec_decode_sampler`` is created.
            disable_by_batch_size: When the running queue size reaches this
                value (``>=``), speculation is disabled for the step, and
                permanently for the sequence groups in that batch. None means
                speculation is never disabled by batch size.
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
        self.generators = (scorer_runner.get_generators()
                           if scorer_runner else None)
        # Test against None explicitly so that a threshold of 0 ("always
        # disable") is not silently treated as "never disable".
        self.disable_by_batch_size = (float("inf")
                                      if disable_by_batch_size is None else
                                      disable_by_batch_size)
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._metrics = (AsyncMetricsCollector(self.spec_decode_sampler)
                         if metrics_collector is None else metrics_collector)
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization; assigned in init_device().
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.

        Initializes devices and loads weights for the scorer and proposer
        workers, initializes the GPU tensors of the metrics collector and the
        acceptance sampler for this rank, builds the batch-expansion scorer,
        and configures the model samplers for speculative decoding.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        """No-op: both the scorer and proposer models are loaded in
        ``init_device``."""

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of sampling during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True
        scorer_sampler.should_modify_greedy_probs_inplace = True

        self.proposer_worker.set_include_gpu_probs_tensor()
        self.proposer_worker.set_should_modify_greedy_probs_inplace()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``; the CPU block count is the
            scorer's profiled value, unchanged.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.

        Both workers receive the same block counts.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Non-driver ranks run one iteration of the broadcast-driven worker
        loop and return an empty list. The driver broadcasts this step's
        control data and then runs either a non-speculative step or a
        speculative decoding step.

        Args:
            execute_model_req: The batch to run. On the driver, None means
                there is no more work and tells the other ranks to stop their
                execution loop.

        Returns:
            The step's SamplerOutputs; empty on non-driver ranks and when
            ``execute_model_req`` is None.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. The source rank must match the one the
            # non-driver ranks receive from in _run_non_driver_rank.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        if (num_lookahead_slots == 0
                or len(execute_model_req.seq_group_metadata_list) == 0
                or disable_all_speculation):
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop to perform speculative decoding
        in parallel worker.

        Keeps running non-driver steps until the driver broadcasts an empty
        input.
        """
        while self._run_non_driver_rank():
            pass

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True when speculation should be turned off for this step.

        When the running queue size reaches ``disable_by_batch_size``,
        speculative decoding is disabled to stop trading off throughput for
        latency.
        """
        return (execute_model_req.running_queue_size >=
                self.disable_by_batch_size)

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """When ``disable_all_speculation`` is True, set
        ``num_speculative_tokens`` to 0 on every sequence group in the batch;
        otherwise do nothing.
        """
        if not disable_all_speculation:
            return

        for seq_group_metadata in seq_group_metadata_list:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            seq_group_metadata.num_speculative_tokens = 0

    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> SamplerOutput:
        """
        Creates and returns a `SamplerOutput` with only the sampled token IDs
        being serialized to CPU & populated in `CompletionSequenceGroupOutput`.
        All other parameters in `CompletionSequenceGroupOutput` related to log
        probabilities are skipped.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
            was executed.
            sampler_output (SamplerOutput): The output from the sampler with
            only GPU tensors populated.

        Returns:
            SamplerOutput: A new `SamplerOutput` instance containing a list of
            `CompletionSequenceGroupOutput` objects with only sampled token
            IDs populated.
        """
        seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list)
        # One device-to-host copy for the whole batch.
        sampled_token_ids_list = sampler_output.sampled_token_ids.tolist()
        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = [
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[index][0],
                    # Placeholder rank/logprob: logprobs are disabled.
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                ) for index, seq_id in enumerate(seq_ids)
            ]
        return SamplerOutput(outputs=completion_seq_group_output_list)

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation. The input is
        sent to the proposer and scorer model so that the KV cache is consistent
        between the two. When skip_proposer is True, the proposer model is
        not called, meaning that the kv-cache in proposer for requests is not
        updated, so they cannot enable spec decode in the rest decoding.

        Returns:
            A single-element list with the scorer's SamplerOutput (or a
            token-ids-only copy of it when logprobs are disabled).
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        sampler_output = scorer_outputs[0]

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    execute_model_req.seq_group_metadata_list, hidden_states)
            else:
                self.previous_hidden_states.update(
                    execute_model_req.seq_group_metadata_list, hidden_states)

        if self._disable_logprobs:
            sampler_output_to_return = (
                self._serialize_sampler_output_no_logprobs(
                    execute_model_req=execute_model_req,
                    sampler_output=sampler_output))
        else:
            sampler_output_to_return = sampler_output

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        # When logprobs are enabled the returned object *is* sampler_output,
        # so it is returned without these tensors.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return [sampler_output_to_return]

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).

        The step plan comes from the driver's broadcast. When the driver
        reports ``disable_all_speculation``, it runs only the scorer
        (``_run_no_spec(..., skip_proposer=True)``), so the proposer is
        skipped here too to keep every rank executing the same steps.

        Returns True if there are remaining sequences to process.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # Mirror the driver: it does not call the proposer at all when all
        # speculation is disabled, so running it here would leave this rank
        # executing a different sequence of model steps than the driver.
        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the proposer
            # model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        self.scorer_worker.execute_model()
        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.

        Raises:
            RuntimeError: If the proposer produced no proposals while this
                worker was built with ``allow_zero_draft_token_step=False``.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer. They are
        # consumed by this step, so the cached copy is cleared.
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        with Timer() as verification_timer:
            accepted_token_ids, target_logprobs = self._verify_tokens(
                execute_model_req.seq_group_metadata_list, proposal_scores,
                proposals, execute_model_req.num_lookahead_slots)

        # Proposal time is reported per lookahead slot. num_lookahead_slots
        # is non-zero here because execute_model routes 0 to _run_no_spec.
        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Sequences with a non-zero proposal length are verified by the
        acceptance sampler. Sequences with zero proposals keep only the
        scorer's first token, padded with -1 to ``max_proposal_len + 1``
        columns. Rows are returned in the order of ``seq_group_metadata_list``.
        If the scorer returned hidden states, the hidden state at each
        sequence's last emitted position is stored in
        ``self.previous_hidden_states`` for the next step.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            # The sampler only sees the speculative sub-batch, whose row j is
            # seq_group_metadata_list[spec_indices[j]]. Key each generator by
            # that row rather than by the position in the full batch, which
            # would misapply seeds (or index past the sub-batch) whenever
            # non-speculative sequences are interleaved.
            seeded_seqs: Dict[int, torch.Generator] = {}
            for row, batch_idx in enumerate(spec_indices):
                sgm = seq_group_metadata_list[batch_idx]
                if sgm.sampling_params.seed is not None:
                    seeded_seqs[row] = self.generators[sgm.request_id]
            sampler_extra_kwargs["seeded_seqs"] = seeded_seqs

        accepted_token_ids = self.spec_decode_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]

            # -1 padding becomes 0 after the +1, so the non-zero count is the
            # number of emitted tokens; minus one gives the last position.
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
                                                       hidden_states)

        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        One SamplerOutput is built per step (column of ``accepted_token_ids``).
        Within an emitted step every sequence contributes one output, with -1
        marking positions that carry no token. Iteration stops at the first
        step in which every sequence holds -1, so trailing all-padding steps
        are not emitted.

        Args:
            seq_group_metadata_list: Metadata of the batch. Row ``i`` of
                ``accepted_token_ids`` is paired with the ``i``-th sequence id
                derived from this list.
            accepted_token_ids: Accepted token ids per sequence and step,
                with -1 as the padding value.
            target_logprobs: Target-model logprobs per sequence and step. Not
                read when logprobs are disabled.
            k: Forwarded to the metrics collector when polling rejection
                sampler metrics.
            stage_times: (average_time_per_proposal_tok_ms, scoring_time_ms,
                verification_time_ms); logged only when metrics are returned.

        Returns:
            One SamplerOutput per emitted step.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        max_logprobs = self.scorer_worker.model_config.max_logprobs

        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the device. Instead create
            # empty/dummy lists of the expected shapes.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps, max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list: List[SamplerOutput] = []
        for step_index in range(num_steps):
            step_token_ids = accepted_token_ids_by_step[step_index]
            # Stop at the first step where no sequence holds a token.
            if all(token_id == -1 for token_id in step_token_ids):
                break

            step_ranks = accepted_token_id_ranks_by_step[step_index]
            step_logprobs = accepted_token_id_logprobs_by_step[step_index]
            step_topk_indices = topk_indices_by_step[step_index]
            step_topk_logprobs = topk_logprobs_by_step[step_index]

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=step_token_ids[sequence_index],
                        token_id_logprob_rank=step_ranks[sequence_index],
                        token_id_logprob=step_logprobs[sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=step_topk_indices[sequence_index]
                        [:num_logprobs],
                        topk_logprobs=step_topk_logprobs[sequence_index]
                        [:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            # NOTE(review): this indexes the first emitted step. If no step
            # was emitted (e.g. an empty batch, where the loop above breaks
            # immediately) this raises IndexError.
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)

        return sampler_output_list

753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """Serialize target-model logprob information into Python lists.

        Four step-major lists are returned:
            - ranks of the accepted tokens, shaped (num_steps, batch_size)
            - log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - top-k log probabilities, shaped (num_steps, batch_size, num_top_k)
            - top-k token IDs, shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Target-model log
                probabilities, shaped (num_steps, batch_size, vocab_size).
            accepted_token_ids_by_step (torch.Tensor): Accepted token ids,
                shaped (num_steps, batch_size).
            num_top_k (int): How many top log probabilities to return per
                position.

        Returns:
            A tuple (ranks, logprobs, topk_logprobs, topk_token_ids) of
            nested Python lists.
        """
        sampled_ranks, sampled_logprobs = get_sampled_token_logprobs(
            logprob_tensor=target_logprobs_by_step,
            sampled_token_ids=accepted_token_ids_by_step,
        )
        # The top-k set may or may not include the accepted token.
        topk_values, topk_ids = target_logprobs_by_step.topk(k=num_top_k,
                                                             dim=-1)
        return (
            sampled_ranks.tolist(),
            sampled_logprobs.tolist(),
            topk_values.tolist(),
            topk_ids.tolist(),
        )

871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.

        For every id in ``execute_model_req.finished_requests_ids``, the
        request's entry is removed from ``_request_id_seq_id_mapping`` and each
        of its sequence ids is discarded from
        ``_seq_with_bonus_token_in_last_step``. A request with no recorded
        entry is simply skipped. ``pop`` with a default is used so this works
        whether or not the mapping auto-creates missing keys.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            finished_seq_ids = self._request_id_seq_id_mapping.pop(
                finished_request, ())
            self._seq_with_bonus_token_in_last_step.difference_update(
                finished_seq_ids)

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank reported by the scorer worker, reused as this worker's rank."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of the scorer worker, reused as this worker's device."""
        scorer = self.scorer_worker
        return scorer.device

918
919
920
921
    @property
    def _driver_rank(self) -> int:
        return 0

922
923
924
925
926
927
928
929
930
931
    def get_cache_block_size_bytes(self):
        """Not supported for this worker: always raises NotImplementedError.

        The method exists only so that workers can be composed inside a
        SpecDecodeWorker. Nesting a SpecDecodeWorker inside another
        SpecDecodeWorker is left undefined for now, though it could be
        supported later. See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError()

932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model alone, calculate how many blocks each of
    the draft and target models should get.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    pick the largest block count such that, when allocated by both models,
    the total KV-cache memory is no larger than what total_num_gpu_blocks
    blocks of the target model alone would use:

        floor(total_num_gpu_blocks * scorer_bytes /
              (scorer_bytes + proposer_bytes))

    The computation uses integer floor division, so it stays exact for
    arbitrarily large values rather than going through a rounded float
    quotient.

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)
    new_num_gpu_blocks = int(total_num_gpu_blocks *
                             scorer_cache_block_size_bytes //
                             combined_block_size_bytes)

    return new_num_gpu_blocks