# spec_decode_worker.py
# SPDX-License-Identifier: Apache-2.0

3
import copy
4
from collections import defaultdict
5
from functools import cached_property
6
from typing import Any, Dict, List, Optional, Set, Tuple, Type
7
8

import torch
9
import torch.nn as nn
10

11
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
12
from vllm.distributed.communication_op import broadcast_tensor_dict
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
15
from vllm.model_executor.layers.sampler import SamplerOutput
16
from vllm.model_executor.layers.spec_decode_base_sampler import (
17
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
18
19
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
20
from vllm.platforms import current_platform
21
22
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
23
                           HiddenStates, SequenceGroupMetadata,
24
                           get_all_seq_ids_and_request_ids)
25
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
26
27
28
29

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

30
31
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
32
from vllm.spec_decode.medusa_worker import MedusaWorker
33
from vllm.spec_decode.metrics import AsyncMetricsCollector
34
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
35
from vllm.spec_decode.mqa_scorer import MQAScorer
36
from vllm.spec_decode.multi_step_worker import MultiStepWorker
37
from vllm.spec_decode.ngram_worker import NGramWorker
38
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
39
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
40
from vllm.spec_decode.target_model_runner import TargetModelRunner
41
42
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
43
                                   get_all_num_logprobs,
44
                                   get_sampled_token_logprobs, nvtx_range,
45
                                   split_batch_by_proposal_len)
46
47
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
48
49

logger = init_logger(__name__)
50
51


52
53
54
55
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from the speculative config in ``vllm_config``.

    This is the entrypoint for Executors that go through WorkerWrapper. The
    target (scorer) worker is built from the caller's arguments with
    ``TargetModelRunner`` as its model runner; the draft worker arguments are
    derived from a copy of the same kwargs with a draft-specific
    ``vllm_config`` and the n-gram lookup bounds added.

    Args:
        *args: Positional arguments forwarded to the target worker class.
        **kwargs: Keyword arguments forwarded to the target worker class.
            ``vllm_config`` is read from here and its ``speculative_config``
            must not be None.

    Returns:
        The SpecDecodeWorker produced by ``SpecDecodeWorker.create_worker``.

    Raises:
        NotImplementedError: If pipeline parallelism is enabled
            (``pipeline_parallel_size > 1``).
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    spec_cfg: SpeculativeConfig = vllm_config.speculative_config
    assert spec_cfg is not None

    if vllm_config.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError("Speculative decoding is currently "
                                  "incompatible with pipeline parallelism")

    # Snapshot the caller's kwargs for the draft worker *before* the target
    # model runner override is written into ``kwargs``.
    draft_worker_kwargs = dict(kwargs)
    kwargs["model_runner_cls"] = TargetModelRunner

    # Resolve the target worker class from the spec-decode worker class name.
    target_cfg = copy.deepcopy(vllm_config)
    target_parallel_cfg = target_cfg.parallel_config
    target_parallel_cfg.worker_cls = target_parallel_cfg.sd_worker_cls
    target_worker_cls = resolve_obj_by_qualname(target_parallel_cfg.worker_cls)
    target_worker = target_worker_cls(*args, **kwargs)

    # Mirror the SpeculativeConfig's disable_logprobs setting onto the
    # TargetModelRunner instance.
    target_worker.model_runner.disable_logprobs = spec_cfg.disable_logprobs

    # Derive the draft worker's config: draft model, matching quantization
    # config, and the draft parallel config.
    draft_cfg = copy.deepcopy(vllm_config)
    draft_cfg.model_config = spec_cfg.draft_model_config
    draft_cfg.quant_config = VllmConfig._get_quantization_config(
        draft_cfg.model_config,
        vllm_config.load_config,
    )
    draft_parallel_cfg = spec_cfg.draft_parallel_config
    draft_parallel_cfg.worker_cls = draft_cfg.parallel_config.sd_worker_cls
    draft_cfg.parallel_config = draft_parallel_cfg
    # TODO allow draft-model specific load config.

    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        vllm_config=draft_cfg,
        ngram_prompt_lookup_max=spec_cfg.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=spec_cfg.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_mqa_scorer=spec_cfg.speculative_disable_mqa_scorer,
        disable_by_batch_size=spec_cfg.speculative_disable_by_batch_size,
        draft_token_acceptance_method=spec_cfg.draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=(
            spec_cfg.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            spec_cfg.typical_acceptance_sampler_posterior_alpha),
        disable_logprobs=spec_cfg.disable_logprobs,
        disable_log_stats=spec_cfg.disable_log_stats,
    )


115
# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combination becomes valid.
117
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
118
119
120
121
122
123
124
125
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

126
    See https://github.com/vllm-project/vllm/pull/2188 and
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model proposal is implemented (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

143
    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Assemble a SpecDecodeWorker around an existing scorer worker.

        Chooses the proposer worker, the acceptance sampler, and whether the
        MQA scorer may be used, then constructs the SpecDecodeWorker.

        Args:
            scorer_worker: The target-model worker used for scoring.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain ``vllm_config``, ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``. This dict is mutated: the two
                n-gram keys are popped, and ``device_type`` or
                ``model_runner_cls`` may be added.
            disable_mqa_scorer: If True, never use the MQA scorer. Even when
                False, the MQA scorer is disabled if the target attention
                backend is not "FLASH_ATTN", if the draft model's
                ``max_model_len`` is smaller than the target's, or if the
                target model is not running in eager mode.
            disable_by_batch_size: Forwarded to the SpecDecodeWorker.
            draft_token_acceptance_method: Either ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Passed to
                TypicalAcceptanceSampler as ``posterior_threshold``.
            typical_acceptance_sampler_posterior_alpha: Passed to
                TypicalAcceptanceSampler as ``posterior_alpha``.
            disable_logprobs: Forwarded to the SpecDecodeWorker.
            disable_log_stats: Forwarded to the SpecDecodeWorker.

        Returns:
            The configured SpecDecodeWorker.

        Raises:
            NotImplementedError: If an EAGLE draft model is configured with
                tensor parallel size other than 1.
            ValueError: If ``draft_token_acceptance_method`` is not one of the
                supported values.
        """
        allow_zero_draft_token_step = True

        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
            'vllm_config'].parallel_config

        if ngram_prompt_lookup_max > 0:
            # N-gram prompt lookup: no draft model is run.
            draft_worker_kwargs[
                "device_type"] = scorer_worker.device_config.device.type
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size
            draft_model_type = draft_model_config.hf_config.model_type

            if draft_model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    if current_platform.is_cuda_alike():
                        draft_worker_kwargs[
                            "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")

                    # Zero-draft-token steps are disallowed when the draft
                    # model runs with TP != 1.
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        # Fail fast on an unknown acceptance method instead of passing a None
        # sampler into __init__, where it would only surface as an opaque
        # AttributeError.
        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            raise ValueError(
                "Unsupported draft_token_acceptance_method "
                f"{draft_token_acceptance_method!r}; expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            if scorer_worker.model_runner.attn_backend.get_name(
            ) != "FLASH_ATTN":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if draft_model_config and \
                draft_model_config.max_model_len < \
                    scorer_worker.model_config.max_model_len:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
243

244
245
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. It is either an instance of
                RejectionSampler or of TypicalAcceptanceSampler.
            disable_mqa_scorer: If set to True, disable the MQA scorer and use
                the BatchExpansionTop1Scorer instead.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
            disable_log_stats: If set to True, disable periodic printing of
                speculative stage times.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. When None, an AsyncMetricsCollector is
                created around ``spec_decode_sampler``.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests. A
                falsy value (None or 0) is stored as infinity.
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker

        # The scorer worker may not expose a model runner; in that case no
        # generators are available.
        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
        if scorer_runner:
            self.generators = scorer_runner.get_generators()
        else:
            self.generators = None

        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step

        if metrics_collector is None:
            self._metrics = AsyncMetricsCollector(self.spec_decode_sampler)
        else:
            self._metrics = metrics_collector

        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype

        # Declared here, assigned lazily in init_device().
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None

        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats
317

318
    def init_device(self) -> None:
        """Initialize devices, load models, and build the scorer.

        Both workers have their device initialized and then their model
        loaded, the metrics collector and acceptance sampler allocate their
        tensors, the scorer implementation is selected, and finally the model
        samplers are configured for speculative decoding.
        """
        # The scorer worker goes first in case the proposer model has a
        # smaller TP degree than the target worker.
        workers = (self.scorer_worker, self.proposer_worker)
        for worker in workers:
            worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        for worker in workers:
            worker.load_model()

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        use_mqa_scorer = not self.disable_mqa_scorer
        if use_mqa_scorer:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")
        else:
            scorer_cls = BatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

350
351
352
    def load_model(self, *args, **kwargs):
        """Do nothing; accepts and ignores any arguments.

        The scorer and proposer models are loaded by ``init_device``.
        """
        return None

353
354
355
    def _configure_model_sampler_for_spec_decode(self):
        """Make the model samplers emit GPU tensors for spec decode.

        Keeping sampler output on device, instead of moving it to CPU and
        serializing it, significantly reduces the overhead of sampling during
        verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" decision made
        outside of the model/sampler, so that the "last-mile" worker object
        which interfaces with the scheduler can serialize and incur the
        performance hit only when necessary. That would let the worker run
        several iterations in a row without paying the serialization penalty.

        Since this requires a large change to vLLM, it is deferred and the
        broken abstraction boundary is temporarily accepted.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        # Target model: flip both flags directly on its sampler.
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True
        scorer_sampler.should_modify_greedy_probs_inplace = True

        # Proposer: the worker exposes setters for the same two flags.
        self.proposer_worker.set_include_gpu_probs_tensor()
        self.proposer_worker.set_should_modify_greedy_probs_inplace()
378

379
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        The scorer model (typically the larger of the two) is profiled first.
        The memory that its cache would use is then divided between the
        proposer and scorer KV caches so that both end up with the same
        number of blocks.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple, where the GPU block
            count has been reduced to account for the proposer cache and the
            CPU block count is the scorer worker's value unchanged.
        """
        profiled = self.scorer_worker.determine_num_available_blocks()
        num_gpu_blocks, num_cpu_blocks = profiled

        scorer_block_bytes = self.scorer_worker.get_cache_block_size_bytes()
        proposer_block_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, num_gpu_blocks)
        return shared_gpu_blocks, num_cpu_blocks

400
401
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engines of the scorer and proposer workers.

        Both workers receive the same block counts; the scorer worker is
        initialized first.

        Args:
            num_gpu_blocks: Number of GPU cache blocks for each worker.
            num_cpu_blocks: Number of CPU cache blocks for each worker.
        """
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)
408

409
410
411
    def get_model(self) -> nn.Module:
        """Return the model held by the scorer (target) worker."""
        scorer_model = self.scorer_worker.get_model()
        return scorer_model

412
413
    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Args:
            execute_model_req: The batch to run. ``None`` signals that there
                is no more work for now; an empty dict is then broadcast so
                the other workers stop their execution loop.

        Returns:
            An empty list on non-driver ranks and when ``execute_model_req``
            is None. Otherwise the output of ``_run_no_spec`` (speculation
            not used for this step) or of ``_run_speculative_decoding_step``.

        Raises:
            AssertionError: If ``seq_group_metadata_list`` is None, or if a
                non-empty prompt-only batch was scheduled with lookahead
                slots.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Validate before the list is first iterated, so that a missing list
        # is reported with this message rather than a bare TypeError.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        assert seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        all_prompt = all(sgm.is_prompt for sgm in seq_group_metadata_list)
        atleast_one_prompt = any(sgm.is_prompt
                                 for sgm in seq_group_metadata_list)
        all_zero_spec_tokens = all(sgm.num_speculative_tokens == 0
                                   for sgm in seq_group_metadata_list)

        if all_prompt and seq_group_metadata_list:
            assert num_lookahead_slots == 0, (
                "Prompt only runs should have num_lookahead_slots equal to 0. "
                "This should never happen, please file a bug at "
                "https://github.com/vllm-project/vllm/issues")

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        # We expect `num_speculative_tokens` to be None for prefills.
        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.

        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
            # When both chunked prefill and speculative decoding are enabled
            # it is possible that the same batch contains both prefill
            # and decodes. If that happens in the scorer we run the batch
            # as one single forward pass. However, in the proposer we
            # run them as 2 different batches - one for prefill and
            # the other for decodes. The variable indicates to the non-driver
            # worker that there are prefills as part of the speculative batch
            # and hence it needs to run an extra prefill forward pass.
            run_spec_proposer_for_prefill=atleast_one_prompt,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(disable_all_speculation,
                                               seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Run the speculative decoding loop on a parallel worker.

        Calls ``_run_non_driver_rank`` repeatedly until it returns a falsy
        value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

511
512
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return whether speculation should be turned off for this step.

        When the batch size is too large, speculative decoding is disabled to
        stop trading off throughput for latency: the result is True once the
        request's ``running_queue_size`` reaches ``disable_by_batch_size``.
        """
        queue_size = execute_model_req.running_queue_size
        return queue_size >= self.disable_by_batch_size
517
518
519
520
521
522
523
524
525
526
527
528
529
530

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero out ``num_speculative_tokens`` when speculation is disabled.

        Does nothing unless ``disable_all_speculation`` is truthy. Otherwise
        every entry of ``seq_group_metadata_list`` is mutated in place.
        """
        if disable_all_speculation:
            for seq_group_metadata in seq_group_metadata_list:
                # Once num_speculative_tokens is set to 0, the spec decode
                # of this request will be disabled forever.
                # TODO(comaniac): We currently store spec decoding specific
                # state in the global data structure, but we should maintain
                # this state within spec decode worker.
                seq_group_metadata.num_speculative_tokens = 0
531

532
533
    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """
        Build a `SamplerOutput` in which only the token IDs are serialized to
        CPU and populated in `CompletionSequenceGroupOutput`.

        All logprob-related fields are filled with placeholders (rank -1,
        logprob 0.0, empty top-k lists).

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
            was executed.
            sampler_output (SamplerOutput): The output from the sampler with
            only GPU tensors populated.

        Returns:
            List[SamplerOutput]: A single-element list whose `SamplerOutput`
            holds one `CompletionSequenceGroupOutput` per sequence group,
            with only token IDs (and placeholder prompt logprobs where
            requested) populated.
        """
        seq_groups = execute_model_req.seq_group_metadata_list

        # Per sequence group: is this a prompt that asked for prompt logprobs?
        wants_prompt_logprobs = [
            sg.is_prompt and sg.sampling_params.prompt_logprobs is not None
            and sg.sampling_params.prompt_logprobs > 0 for sg in seq_groups
        ]

        sampled_token_ids = sampler_output.sampled_token_ids
        if any(wants_prompt_logprobs):
            # Ignore slots for prompt tokens that are filled with
            # INVALID_TOKEN_ID; subtracting is faster than testing for
            # equality.
            valid_rows = torch.where(sampled_token_ids -
                                     VLLM_INVALID_TOKEN_ID)[0]
            sampled_token_ids = sampled_token_ids[valid_rows]
        sampled_token_ids_list = sampled_token_ids.tolist()

        seq_data_entries = [(seq_id, seq_data) for sg in seq_groups
                            for seq_id, seq_data in sg.seq_data.items()]

        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        # Index of the next unused row of sampled_token_ids_list. It only
        # advances for groups that sample, so that non-terminal prefill
        # chunks stay aligned with their own empty output.
        output_index = 0
        for idx, seq_group_meta in enumerate(seq_groups):
            seq_id, seq_data = seq_data_entries[idx]

            prompt_logprobs = None
            if wants_prompt_logprobs[idx]:
                # Some of these sequences may belong to non-terminal chunks,
                # which may still have to report logprobs for prompts.
                num_computed = seq_data._num_computed_tokens
                if num_computed == 0:
                    start = 1
                else:
                    start = num_computed
                end = num_computed + seq_group_meta.token_chunk_size
                chunk_token_ids = seq_data.get_prompt_token_ids()[start:end]
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in chunk_token_ids
                ]

            if seq_group_meta.do_sample:
                # Sequence with output.
                group_output = create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs)
                output_index += 1
            else:
                # Since we can get chunks here, we dont always have a sampled
                # token (only on last chunk) but we still have to provide an
                # output.
                group_output = CompletionSequenceGroupOutput(
                    samples=[], prompt_logprobs=prompt_logprobs)
            completion_seq_group_output_list.append(group_output)

        return [SamplerOutput(outputs=completion_seq_group_output_list)]
621

622
    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation.

        The input is sent to the proposer and scorer model so that the KV
        cache is consistent between the two. When skip_proposer is True, the
        proposer model is not called, meaning that the kv-cache in the
        proposer is not updated for these requests, so they cannot use spec
        decode for the rest of their decoding.

        Args:
            execute_model_req: The request to run. Its
                ``previous_hidden_states`` field is overwritten when the
                proposer is run.
            skip_proposer: If True, only the scorer worker is executed.

        Returns:
            A list of sampler outputs: either the single scorer output, or
            the result of ``_serialize_sampler_output_no_logprobs`` when
            ``self._disable_logprobs`` is set.
        """
        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (eg non-terminal chunks)
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            if self.previous_hidden_states is None and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_meta_with_hidden)
            elif self.previous_hidden_states and len(
                    seq_group_meta_with_hidden):
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_meta_with_hidden)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there is no
            # additional complexity in worker for spec_decode vs non_spec_decode
            # flow and execute_model doesn't need additional modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)

            self.proposer_worker.execute_model(execute_model_req)

        # Build the return value before the tensors are cleared below; the
        # no-logprobs serialization receives the still-populated output.
        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return
679

680
    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers.

        This is used for both speculation cases (num_lookahead_slots>0) and
        non-speculation cases (e.g. prefill). The control data is received
        from the driver rank via ``broadcast_tensor_dict``.

        Returns:
            True if there are remaining sequences to process, False when the
            broadcast payload is empty.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if data["no_spec"]:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the
            # proposer model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        if not data["no_spec"]:
            self.scorer_worker.execute_model()
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True
714

715
716
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        When `enable_chunked_prefill` is set, scorer will batch decodes and
        prefills, while proposer will sync its KV-cache by running an extra
        forward on prefills.

        Args:
            execute_model_req: The request for this step. Its
                ``previous_hidden_states`` field is rewritten several times
                while the step runs.
            num_lookahead_slots: Must equal
                ``execute_model_req.num_lookahead_slots``; also used as the
                divisor for the per-token proposal time.

        Returns:
            A list of SamplerOutput, each containing a single token per
            sequence.

        Raises:
            RuntimeError: If the proposer produced no proposals and
                ``self._allow_zero_draft_token_step`` is not set.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
        # discard decodes that have already been processed by proposer.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if non_spec_indices:
            all_hidden_states = proposal_scores.hidden_states
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(prefill_hidden_states)
            # Sync proposer KV cache for prefills.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            # TODO avoid sampling here?
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:
            accepted_token_ids, target_logprobs = self._verify_tokens(
                execute_model_req.seq_group_metadata_list, proposal_scores,
                proposals, execute_model_req.num_lookahead_slots)

        # Proposal time is reported as an average per lookahead slot; the
        # other two stages are reported as totals.
        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            prompt_logprobs=proposal_scores.prompt_logprobs
            if not self._disable_logprobs else None,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)
793
794
795
796
797
798
799
800

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        As a side effect, when the scorer returned hidden states,
        ``self.previous_hidden_states`` is replaced with the hidden states
        selected for the next step.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups in the
                batch, in original batch order.
            proposal_scores: Scorer output (probs, token ids, logprobs and
                optionally hidden states).
            proposals: Proposer output (proposal lens, probs and token ids).
            max_proposal_len: Non-speculative rows are expanded to
                ``max_proposal_len + 1`` columns to match the speculative ones.

        Returns:
            A tuple of Tensors, one for the accepted token ids and one for
            the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs[spec_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Only get terminal hidden states for next step
            terminal_metadata = [
                sg for sg in seq_group_metadata_list if sg.do_sample
            ]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # b
            # Drop non-terminal prefill chunks hidden states.
            hidden_states = hidden_states[accepted_index !=
                                          VLLM_INVALID_TOKEN_ID]
            accepted_index = accepted_index[accepted_index !=
                                            VLLM_INVALID_TOKEN_ID]
            assert len(accepted_index) == hidden_states.shape[0] == len(
                terminal_metadata)
            index = accepted_index[:, None, None].expand(-1, 1,
                                                         hs_size)  # b x 1 x d
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, terminal_metadata,
                second_last_token_hidden_states)
        return accepted_token_ids, logprobs
890
891
892
893
894

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups in the
                batch; prefills are expected before decodes.
            accepted_token_ids: Accepted token ids per sequence and step.
            target_logprobs: Scorer logprobs; only read when logprobs are not
                disabled.
            prompt_logprobs: Prompt logprobs indexed per sequence group, or
                None.
            k: Passed to the rejection-sampling metrics collector.
            stage_times: (proposal, scoring, verification) times forwarded to
                ``_maybe_log_stage_times`` when metrics are emitted.

        Returns:
            One SamplerOutput per leading prefill sequence group, followed by
            one SamplerOutput per decode step.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps,
                 self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
        sampler_output_list: List[SamplerOutput] = []

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1).
        for step_index in range(num_steps):
            # Stop at the first step where no decode sequence has a token.
            if all(token_id == -1 for sg, token_id in zip(
                    seq_group_metadata_list,
                    accepted_token_ids_by_step[step_index])
                   if not sg.is_prompt):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step format.
        return sampler_output_list

1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Convert the scorer logprob tensors into nested Python lists.

        The four returned lists are, in order:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Log probabilities of the
            target model, shaped (num_steps, batch_size, vocab_size)
            accepted_token_ids_by_step (torch.Tensor): The accepted token_ids,
            shaped (num_steps, batch_size)
            num_top_k (int): The number of top-k token log probabilities to
            return.

        Returns:
            A tuple containing the lists as described above.
        """
        # Rank and logprob of each accepted token.
        rank_tensor, logprob_tensor = get_sampled_token_logprobs(
            logprob_tensor=target_logprobs_by_step,
            sampled_token_ids=accepted_token_ids_by_step,
        )
        # Top-k logprobs over the vocab dimension; the accepted token may or
        # may not be among them.
        topk_values, topk_token_ids = target_logprobs_by_step.topk(
            k=num_top_k,
            dim=-1,
        )
        # Serialize everything to CPU Python lists.
        ranks_as_list = rank_tensor.tolist()
        logprobs_as_list = logprob_tensor.tolist()
        topk_values_as_list = topk_values.tolist()
        topk_token_ids_as_list = topk_token_ids.tolist()
        return (ranks_as_list, logprobs_as_list, topk_values_as_list,
                topk_token_ids_as_list)

1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Return the ``rank`` attribute of the scorer worker."""
        return self.scorer_worker.rank

    @property
    def device(self):
        """Return the ``device`` attribute of the scorer worker."""
        return self.scorer_worker.device

1227
1228
1229
1230
    @property
    def _driver_rank(self) -> int:
        """Return the rank treated as the driver; this is always 0."""
        return 0

1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: Always; this worker does not implement it.
        """
        raise NotImplementedError

1241
    def start_profile(self):
        """Start profiling on the scorer worker.

        The call is forwarded only when the scorer worker is a ``WorkerBase``
        instance; otherwise this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker.

        The call is forwarded only when the scorer worker is a ``WorkerBase``
        instance; otherwise this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.stop_profile()

1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Size in bytes of one scorer block.
        proposer_cache_block_size_bytes: Size in bytes of one proposer block.
        total_num_gpu_blocks: Blocks allocatable by the target model alone.

    Returns:
        The number of blocks each model should allocate, rounded down.

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    total_scorer_bytes = total_num_gpu_blocks * scorer_cache_block_size_bytes
    bytes_per_block_pair = (proposer_cache_block_size_bytes +
                            scorer_cache_block_size_bytes)
    # Exact integer floor division: going through a float quotient can round
    # up for very large operands and over-report the number of blocks.
    return total_scorer_bytes // bytes_per_block_pair
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional[HiddenStates]:
    """Wrap the scorer's prefill hidden states for the proposer.

    Args:
        prefill_hidden_states: Hidden states from the scorer, or None.

    Returns:
        A ``HiddenStates`` built from the input rolled by one position along
        dim 0, or None when the input is None.
    """
    if prefill_hidden_states is None:
        return None
    # For prefill step in proposer, we run the model for N-1 tokens
    # because Nth token will be processed in the first decode step. For
    # N-1 tokens, the input should be 0:N-1 hidden states which should
    # be concatenated with 1:N token (since output of scorer has to be
    # the input for proposer). Therefore, we shift the hidden states to
    # align n-1th hidden state with nth token.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))