# spec_decode_worker.py
1
from collections import defaultdict
2
from functools import cached_property
3
from typing import Any, Dict, List, Optional, Set, Tuple
4
5
6

import torch

7
from vllm.config import ParallelConfig, SpeculativeConfig
8
from vllm.distributed.communication_op import broadcast_tensor_dict
9
from vllm.logger import init_logger
10
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
11
from vllm.model_executor.layers.sampler import SamplerOutput
12
from vllm.model_executor.layers.spec_decode_base_sampler import (
13
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
14
15
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
16
17
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
18
                           HiddenStates, SequenceGroupMetadata,
19
20
                           get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer, BatchExpansionTreeStyleScorer
21
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
22
23
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
24
from vllm.spec_decode.medusa_worker import MedusaWorker
25
from vllm.spec_decode.metrics import AsyncMetricsCollector
26
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
27
from vllm.spec_decode.multi_step_worker import MultiStepWorker
28
from vllm.spec_decode.ngram_worker import NGramWorker
29
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
30
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
31
from vllm.spec_decode.target_model_runner import TargetModelRunner
32
33
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
34
                                   get_all_num_logprobs,
35
                                   get_sampled_token_logprobs, nvtx_range,
36
                                   split_batch_by_proposal_len)
37
from vllm.worker.worker import Worker
38
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
39
40
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
41
42

# Module-level logger, named after this module.
logger = init_logger(__name__)

def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from the ``speculative_config`` keyword arg.

    This is the entrypoint for executors that construct workers through
    WorkerWrapper. ``*args``/``**kwargs`` build the target (scorer) Worker.
    A copy of ``kwargs`` taken before the target-only overrides is combined
    with draft-specific settings from the speculative config and handed to
    ``SpecDecodeWorker.create_worker`` to build the proposer.
    """
    assert "speculative_config" in kwargs
    spec_cfg: SpeculativeConfig = kwargs.get("speculative_config")
    assert spec_cfg is not None

    # Snapshot kwargs before model_runner_cls is injected below, so the draft
    # worker does not inherit TargetModelRunner.
    draft_kwargs = dict(kwargs)

    kwargs["model_runner_cls"] = TargetModelRunner
    target_worker = Worker(*args, **kwargs)
    # The target model runner follows the disable_logprobs setting of the
    # speculative config.
    target_worker.model_runner.disable_logprobs = spec_cfg.disable_logprobs

    # Draft-model specific overrides.
    # TODO allow draft-model specific load config (load_config=load_config).
    draft_kwargs.update(
        model_config=spec_cfg.draft_model_config,
        parallel_config=spec_cfg.draft_parallel_config,
        ngram_prompt_lookup_max=spec_cfg.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=spec_cfg.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_kwargs,
        disable_by_batch_size=spec_cfg.speculative_disable_by_batch_size,
        draft_token_acceptance_method=spec_cfg.draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=(
            spec_cfg.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            spec_cfg.typical_acceptance_sampler_posterior_alpha),
        disable_logprobs=spec_cfg.disable_logprobs,
        disable_log_stats=spec_cfg.disable_log_stats,
        tree_style_spec_decoding=spec_cfg.tree_style_spec_decoding,
    )


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Supported proposers are a draft model, n-gram prompt lookup,
        MLPSpeculator and Medusa (contributions for more forms are welcome!).
    * Top-1 proposal and scoring are the default; tree-style speculation is
        opt-in via `tree_style_spec_decoding`.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
        tree_style_spec_decoding: bool,
    ) -> "SpecDecodeWorker":
        """Construct a SpecDecodeWorker, choosing proposer and sampler.

        Proposer selection:
        * ``ngram_prompt_lookup_max > 0``: NGramWorker.
        * draft ``hf_config.model_type == "mlp_speculator"``:
          MLPSpeculatorWorker.
        * draft ``hf_config.model_type == "medusa"``: MedusaWorker.
        * otherwise: MultiStepWorker, using TP1DraftModelRunner when the draft
          TP degree is 1.
        Every non-ngram proposer is passed through
        ``SmallerTpProposerWorker.maybe_wrap_worker`` with the draft and
        target TP degrees.

        Args:
            scorer_worker: Target-model worker used for scoring.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain "ngram_prompt_lookup_max" and
                "ngram_prompt_lookup_min"; non-ngram proposers additionally
                read "model_config" and "parallel_config". The caller's dict
                is not modified.
            disable_by_batch_size: Running-queue size at which speculation is
                disabled; None disables the limit.
            draft_token_acceptance_method: "rejection_sampler" or
                "typical_acceptance_sampler".
            typical_acceptance_sampler_posterior_threshold: Passed to
                TypicalAcceptanceSampler as ``posterior_threshold``.
            typical_acceptance_sampler_posterior_alpha: Passed to
                TypicalAcceptanceSampler as ``posterior_alpha``.
            disable_logprobs: Forwarded to SpecDecodeWorker.
            disable_log_stats: Forwarded to SpecDecodeWorker.
            tree_style_spec_decoding: Forwarded to SpecDecodeWorker.

        Raises:
            NotImplementedError: EAGLE draft model with draft TP > 1.
            ValueError: Unknown ``draft_token_acceptance_method``.
        """
        # Work on a copy so the pops/assignments below do not leak into the
        # caller's dict.
        draft_worker_kwargs = dict(draft_worker_kwargs)

        allow_zero_draft_token_step = True

        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        proposer_worker: ProposerWorkerBase
        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_parallel_config: ParallelConfig = draft_worker_kwargs[
                'parallel_config']
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size
            draft_model_type = (
                draft_worker_kwargs["model_config"].hf_config.model_type)

            if draft_model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    draft_worker_kwargs[
                        "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")
                    # Distributed draft workers cannot yet handle a step in
                    # which no draft tokens are generated (TODO: #5814).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail fast: a None sampler would otherwise only surface later as
            # an AttributeError when __init__ reads its dtypes.
            raise ValueError(
                "Unknown draft_token_acceptance_method "
                f"{draft_token_acceptance_method!r}; expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info("Configuring SpecDecodeWorker with sampler=%s",
                    type(spec_decode_sampler))

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step,
            tree_style_spec_decoding=tree_style_spec_decoding)

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
        tree_style_spec_decoding: bool = False,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
            disable_log_stats: If set to True, disable periodic printing of
                speculative stage times.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. Defaults to an AsyncMetricsCollector
                built around ``spec_decode_sampler``.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests. None
                (or 0) means never disable.
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
            tree_style_spec_decoding: Whether to use tree-style speculation
                (tree-style scorer instead of the top-1 batch-expansion
                scorer).
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
        self.generators = scorer_runner.get_generators(
        ) if scorer_runner else None
        # A falsy limit (None or 0) becomes +inf, i.e. never auto-disable.
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazily initialized in init_device(); either the top-1 or the
        # tree-style batch-expansion scorer, hence the common interface type.
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        # Target-model logits handed to the proposer on the next speculative
        # step; stored by _run_no_spec when tree-style decoding is enabled.
        self.previous_logits: Optional[Logits] = None
        # Forwarded to the proposer via execute_model_req on the next
        # speculative step, then reset. NOTE(review): exact contents (KV
        # cache slot indices?) are not visible here — confirm.
        self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats

        self.tree_style_spec_decoding = tree_style_spec_decoding

    def init_device(self) -> None:
        """Bring up both workers, load their models and build the scorer.

        The scorer (target) worker goes first in case the proposer model runs
        with a smaller TP degree than the target worker.
        """
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        scorer_cls = (BatchExpansionTreeStyleScorer
                      if self.tree_style_spec_decoding else
                      BatchExpansionTop1Scorer)
        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        """No-op: both scorer and proposer models are loaded in init_device."""

    def _configure_model_sampler_for_spec_decode(self):
        """Make the model samplers emit GPU tensors for spec decode.

        Keeping sampler outputs on device avoids a transfer to CPU and
        serialization, which significantly reduces overhead of sampling
        during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        target_sampler = self.scorer_worker.model_runner.model.sampler
        target_sampler.include_gpu_probs_tensor = True

        # Tree-style decoding modifies the probs itself in _verify_tokens, so
        # in-place greedy modification is only enabled for top-1 decoding.
        if not self.tree_style_spec_decoding:
            target_sampler.should_modify_greedy_probs_inplace = True

        self.proposer_worker.set_include_gpu_probs_tensor()
        self.proposer_worker.set_should_modify_greedy_probs_inplace()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return ``(num_gpu_blocks, num_cpu_blocks)`` for both KV caches.

        Only the scorer model (typically the larger of the two) is profiled.
        Its GPU block budget is then divided between scorer and proposer KV
        so that both caches hold the same number of blocks. The CPU block
        count from the scorer is returned unchanged.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes(),
            self.proposer_worker.get_cache_block_size_bytes(),
        )
        shared_num_gpu_blocks = split_num_cache_blocks_evenly(
            *block_size_bytes, num_gpu_blocks)

        return shared_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize scorer and proposer cache engines with the same sizes."""
        cache_sizes = dict(num_gpu_blocks=num_gpu_blocks,
                           num_cpu_blocks=num_cpu_blocks)
        self.scorer_worker.initialize_cache(**cache_sizes)
        self.proposer_worker.initialize_cache(**cache_sizes)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Non-driver ranks run one step of the non-driver protocol and return
        an empty list. On the driver rank:
        * a None request broadcasts an empty dict (stopping the non-driver
          loops) and returns an empty list;
        * otherwise the step's control info is broadcast, and either a
          non-speculative step or a speculative decoding step is run.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. Use the same source rank the non-driver loop
            # listens on.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        # Validate up front: the no_spec computation below iterates this
        # list, so a None value would otherwise fail with an opaque TypeError
        # before this check was ever reached.
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
            sgm.num_speculative_tokens == 0
            for sgm in execute_model_req.seq_group_metadata_list)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        #
        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Drive the non-driver protocol until the driver says stop.

        Each iteration runs one step via _run_non_driver_rank, which returns
        False once the driver broadcasts an empty dict.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True when the running queue has reached the disable limit.

        Large batches turn speculation off so that throughput is not traded
        off for latency. The limit is ``self.disable_by_batch_size``, which is
        +inf when no limit was configured.
        """
        running_queue_size = execute_model_req.running_queue_size
        return running_queue_size >= self.disable_by_batch_size

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` on every group if speculation is off.

        Once a group's num_speculative_tokens is 0, spec decode for that
        request stays disabled from then on. Does nothing when
        ``disable_all_speculation`` is False.
        """
        if disable_all_speculation:
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for sgm in seq_group_metadata_list:
                sgm.num_speculative_tokens = 0

    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> SamplerOutput:
        """Build a CPU-side SamplerOutput that carries only sampled token IDs.

        Every `CompletionSequenceGroupOutput` gets a placeholder logprob
        (rank -1, logprob 0.0, empty top-k). Groups that are in their prompt
        phase and requested prompt logprobs get placeholder prompt logprobs
        for every prompt token except the first.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
                was executed.
            sampler_output (SamplerOutput): The output from the sampler with
                only GPU tensors populated.

        Returns:
            SamplerOutput: A new `SamplerOutput` whose outputs are
            `CompletionSequenceGroupOutput` objects with only token IDs
            populated.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # One flag per sequence group: does it need (placeholder) prompt
        # logprobs?
        needs_prompt_logprobs_flags = [
            sgm.is_prompt and sgm.sampling_params.prompt_logprobs is not None
            and sgm.sampling_params.prompt_logprobs > 0
            for sgm in seq_group_metadata_list
        ]

        sampled_token_ids = sampler_output.sampled_token_ids
        if any(needs_prompt_logprobs_flags):
            # Prompt-token slots are filled with VLLM_INVALID_TOKEN_ID; keep
            # only the other rows. Subtracting is faster than testing for
            # equality.
            valid_rows = torch.where(sampled_token_ids -
                                     VLLM_INVALID_TOKEN_ID)[0]
            sampled_token_ids = sampled_token_ids[valid_rows]
        sampled_token_ids_list = sampled_token_ids.tolist()

        # NOTE(review): the flags above are per sequence group while these
        # entries are per sequence; they line up only if every group holds a
        # single sequence — confirm that invariant for spec decode.
        seq_data_entries = ((seq_id, seq_data)
                            for sgm in seq_group_metadata_list
                            for seq_id, seq_data in sgm.seq_data.items())

        outputs: List[CompletionSequenceGroupOutput] = []
        for row, ((seq_id, seq_data), needs_prompt_logprobs) in enumerate(
                zip(seq_data_entries, needs_prompt_logprobs_flags)):
            prompt_logprobs = None
            if needs_prompt_logprobs:
                # No prompt logprobs for the first prompt token.
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in seq_data.get_prompt_token_ids()[1:]
                ]

            outputs.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[row][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))

        return SamplerOutput(outputs=outputs)

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run one generation step with no speculation.

        The batch goes to the scorer and, unless ``skip_proposer`` is True,
        to the proposer as well so both KV caches stay consistent. When the
        proposer is skipped its KV cache is not updated for these requests,
        so they cannot enable spec decode for the rest of their decoding.

        Target hidden states (and, in tree-style mode, logits) are stashed
        for the proposer's next speculative step.
        """
        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        sampler_output = scorer_outputs[0]
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            if any(sgm.is_prompt for sgm in seq_group_metadata_list):
                # Drop rows for prompt tokens (sampled id is
                # VLLM_INVALID_TOKEN_ID for those slots).
                hidden_states = hidden_states[torch.where(
                    sampler_output.sampled_token_ids -
                    VLLM_INVALID_TOKEN_ID)[0]]
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_metadata_list)
            else:
                self.previous_hidden_states.update(hidden_states,
                                                   seq_group_metadata_list)

        # Store logits from target model execution (tree-style only).
        if self.tree_style_spec_decoding:
            logits = sampler_output.logits
            if logits is not None:
                # NOTE(review): unlike hidden_states above, prompt-token rows
                # are not filtered out of logits — confirm logits never carry
                # such rows.
                if self.previous_logits is None:
                    self.previous_logits = Logits(logits,
                                                  seq_group_metadata_list)
                else:
                    self.previous_logits.update(logits,
                                                seq_group_metadata_list)

        if not skip_proposer:
            # Attach the prefill hidden states here so the proposer's
            # execute_model needs no separate spec/non-spec code path.
            execute_model_req.previous_hidden_states = (
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states))
            self.proposer_worker.execute_model(execute_model_req)

        if self._disable_logprobs:
            output = self._serialize_sampler_output_no_logprobs(
                execute_model_req=execute_model_req,
                sampler_output=sampler_output)
        else:
            output = sampler_output

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return [output]

    def _run_non_driver_rank(self) -> bool:
        """Mirror one driver step on a non-driver rank.

        Receives the step's control dict from the driver and runs the scorer
        and proposer in the same order as the driver: scorer -> proposer for
        non-speculative steps (e.g. prefill), proposer -> scorer otherwise.

        Returns False when the driver broadcast an empty dict (stop signal),
        True after a step has been executed.
        """
        assert self.rank != self._driver_rank

        step_info = broadcast_tensor_dict(src=self._driver_rank)
        if not step_info:
            return False

        num_lookahead_slots = step_info["num_lookahead_slots"]
        no_spec = step_info["no_spec"]

        # Non-speculative step: run the scorer first so its hidden states can
        # be propagated to the proposer when needed.
        if no_spec:
            self.scorer_worker.execute_model()

        if not step_info["disable_all_speculation"]:
            if self.tree_style_spec_decoding:
                # NOTE(review): in tree-style mode the proposer is skipped on
                # non-speculative steps here, while the driver's _run_no_spec
                # still calls proposer_worker.execute_model — confirm the
                # proposer does no collective work in that call.
                if not no_spec:
                    self.proposer_worker.sampler_output(None, None, None)
            else:
                # Run the proposer once per lookahead slot, and at least once
                # even with zero slots since it may hold KV. In the future the
                # proposer should decide how many times it runs.
                for _ in range(max(num_lookahead_slots, 1)):
                    self.proposer_worker.execute_model()

        # Speculative step: the scorer verifies after the proposer ran.
        if not no_spec:
            self.scorer_worker.execute_model()

        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Run one speculative decoding step.

        The proposer worker drafts speculative tokens for every sequence, the
        scorer evaluates them, and the verification stage decides which
        drafts are accepted. In tree-style mode, the KV-cache moves for the
        accepted path are also computed here (consumed by the next request).

        Returns:
            A list of SamplerOutput, each holding a single token per sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Hand the state carried over from the previous step to this request
        # and clear it on the worker: target hidden states, target logits and
        # pending KV-cache slot moves.
        (execute_model_req.previous_hidden_states,
         self.previous_hidden_states) = (self.previous_hidden_states, None)
        (execute_model_req.previous_logits,
         self.previous_logits) = (self.previous_logits, None)
        (execute_model_req.kvcache_slot_to_be_moved,
         self.kvcache_slot_to_be_moved) = (self.kvcache_slot_to_be_moved,
                                           None)

        with Timer() as proposal_timer:
            # Ask the draft worker for speculative tokens.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            # TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        if self.tree_style_spec_decoding:
            # The target model needs the tree attention mask and the
            # per-node position ids produced by the proposer.
            execute_model_req.tree_attn_masks = proposals.tree_attn_masks
            execute_model_req.tree_position_ids = proposals.tree_position_ids

        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req, proposals)

        with Timer() as verification_timer:
            (accepted_token_ids, target_logprobs, select_indices_list,
             accept_lengths) = self._verify_tokens(
                 execute_model_req.seq_group_metadata_list, proposal_scores,
                 proposals, execute_model_req.num_lookahead_slots)

            if self.tree_style_spec_decoding:
                # Relocate the KV-cache entries of the accepted tree path.
                self.move_caches(execute_model_req, select_indices_list,
                                 accept_lengths)

        per_token_proposal_ms = (proposal_timer.elapsed_time_ms /
                                 num_lookahead_slots)
        stage_times = (per_token_proposal_ms, scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
               Optional[List[int]]]:
        """Determine which speculative tokens are accepted.

        Acceptance uses the probabilities of each token according to the
        proposer and scorer models. As a side effect, the target-model state
        needed by the next proposal step is stored on the worker
        (``self.previous_hidden_states`` and, on the tree path,
        ``self.previous_logits``).

        Returns:
            A 4-tuple of:
            - the accepted token ids, rearranged into the original seq group
              order; rows of non-speculative sequences are padded with -1
              after their first token;
            - the scorer's logprobs (``proposal_scores.logprobs``);
            - on the tree path (``proposals.cart_candidates`` is not None),
              the retrieve-index path of each sequence's accepted candidate
              (one tensor per sequence); otherwise an empty list;
            - on the tree path, the per-sequence accept lengths reported by
              the sampler; otherwise None.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs
        if non_spec_indices:
            proposal_verifier_probs = proposal_verifier_probs[spec_indices]

        if self.tree_style_spec_decoding:
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = proposal_verifier_probs[:,
                                                              retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method. They may be None;
        # only select the speculative rows when a tensor is present (indexing
        # None would raise a TypeError on mixed spec/non-spec batches).
        proposal_probs = proposals.proposal_probs
        if non_spec_indices and proposal_probs is not None:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers. cart_candidates may be None (the non-tree branch
        # below relies on that), so guard the indexing the same way.
        cart_candidates = proposals.cart_candidates
        if non_spec_indices and cart_candidates is not None:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            # These two lists start empty and are read back (and indexed per
            # sequence) on the tree path below, so the sampler is expected
            # to fill them in place.
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []
            sampler_extra_kwargs["first_step_flags"] = [
                bool(next(iter(sgm.seq_data.values())).get_first_step_flag())
                for sgm in seq_group_metadata_list
            ]

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to the accepted
        # token ids tensor, padded with -1. The width is chosen per mode so
        # the rows concatenate with the sampler's output.
        if self.tree_style_spec_decoding:
            pad_width = max_proposal_len
        else:
            pad_width = max_proposal_len + 1
        non_spec_token_ids = non_spec_token_ids.expand(-1, pad_width).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        select_indices_list: List[torch.Tensor] = []
        accept_lengths: Optional[List[int]] = None

        if cart_candidates is None:
            if hidden_states is not None:
                # Contract hidden states based on accepted tokens
                hs_size = hidden_states.shape[-1]

                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
                index = accepted_index[:, None, None].expand(-1, 1, hs_size)
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
                # Store hidden states from target model for subsequent decode step
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_metadata_list,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices
            batch_size = len(seq_group_metadata_list)

            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            retrieve_indices = retrieve_indices.cpu()

            previous_logits_list = []
            previous_hidden_state_list = []
            for i in range(batch_size):
                best = best_candidates[i]
                accept_length = accept_lengths[i]
                previous_logits_list.append(
                    logits[i, best, accept_length].unsqueeze(0))
                select_indices = retrieve_indices[best, :accept_length + 1]
                select_indices_list.append(select_indices)
                # Hidden state of the last accepted node: [1, hidden_size].
                previous_hidden_state_list.append(
                    hidden_states[i, select_indices[-1]].unsqueeze(0))

            self.previous_logits = Logits(
                torch.cat(previous_logits_list, dim=0),
                seq_group_metadata_list)

            # [batch_size, hidden_size]
            self.previous_hidden_states = HiddenStates(
                torch.cat(previous_hidden_state_list, dim=0),
                seq_group_metadata_list)

        return accepted_token_ids, logprobs, select_indices_list, accept_lengths
    
    def move_caches(self, execute_model_req: ExecuteModelRequest, 
                    select_indices_list: List[torch.Tensor], 
                    accept_lengths: List[int]):
        """Given selected output tokens and accept length,
        move kv_caches of selected tokens to right positions.
        """
        seq_lens = []
        for sg in execute_model_req.seq_group_metadata_list:
            seq_ids = list(sg.seq_data.keys())
            
            for seq_id in seq_ids:
                seq_data = sg.seq_data[seq_id]
                seq_len = seq_data.get_len()
                token_chunk_size = sg.token_chunk_size
                context_len = seq_len - 1
                seq_len = min(seq_len, context_len + token_chunk_size)

                # first step of tree-style decoding need to ignore first generated token
874
                if seq_data.get_first_step_flag():
875
                    seq_len -= 1
876
877
878

                # move cache is the last step of tree decoding, so set first_step_flag to false
                seq_data.set_first_step_flag(False)   
879
880
881
882
                seq_lens.append(seq_len)

        model_input = self.scorer._scorer_worker.model_input
        block_tables = None
883
884
        if hasattr(model_input, 'attn_metadata') and hasattr(model_input.attn_metadata, 'block_tables_list'):
            block_tables = model_input.attn_metadata.block_tables_list
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900

        if block_tables is None:
            raise RuntimeError("Can not get block_tables from model_input.")

        cache_engine = self.scorer._scorer_worker.cache_engines[execute_model_req.virtual_engine]
        block_size = cache_engine.block_size
        batch_size = len(select_indices_list)
        block_table_stride = len(block_tables) // batch_size

        select_indices_slot_mapping = []
        target_slot_mapping = []
        for i in range(batch_size):
            accept_legth = accept_lengths[i]

            if accept_legth > 0:
                select_indices = select_indices_list[i][1:] + seq_lens[i]
901
                select_indices = select_indices.tolist()
902
903
904
905
                self.compute_slot_mapping(select_indices_slot_mapping, i*block_table_stride,
                                            select_indices, block_size, block_tables)

                target_indices = torch.arange(accept_legth+1)[1:] + seq_lens[i]
906
                target_indices = target_indices.tolist()
907
908
909
910
911
912
913
914
915
916
917
                self.compute_slot_mapping(target_slot_mapping, i*block_table_stride, 
                                            target_indices, block_size, block_tables)

        if len(select_indices_slot_mapping) >0:
            select_indices_slot_tensor = torch.tensor(select_indices_slot_mapping,
                                            dtype=torch.long,
                                            device=self.device).view(-1, 1)
            target_slot_mapping_tensor = torch.tensor(target_slot_mapping,
                                            dtype=torch.long,
                                            device=self.device).view(-1, 1)
            src_dst_tensor = torch.cat([select_indices_slot_tensor, target_slot_mapping_tensor], dim=-1) #[batch_size*T, 2]
918

919
            self.kvcache_slot_to_be_moved = src_dst_tensor
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
    
    def compute_slot_mapping(self, slot_mapping: List[int],
                             seq_id: int, select_indices: List[int],
                             block_size: int,
                             block_tables: List[List[int]]):
        """Append the KV-cache slot of each token position to ``slot_mapping``.

        A token at position ``p`` is stored in physical block
        ``block_table[p // block_size]`` at offset ``p % block_size``, so its
        flat slot number is ``block_number * block_size + offset``. No
        padding or masking is applied; one slot is appended per position.

        Args:
            slot_mapping: Output list, extended in place in the order of
                ``select_indices``.
            seq_id: Row of ``block_tables`` to use. In this file the caller
                (``move_caches``) passes a block-table row index here.
            select_indices: Token positions to map.
            block_size: Number of tokens per KV-cache block.
            block_tables: Physical block numbers, one list per row.
        """
        block_table = block_tables[seq_id]
        slot_mapping.extend(
            block_table[index // block_size] * block_size + index % block_size
            for index in select_indices)

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Turn accepted token ids into one SamplerOutput per decoding step.

        Each step holds one output per sequence; sequences with fewer
        accepted tokens report -1 in the padded steps. Output stops at the
        first step where every sequence's token is -1. Afterwards the
        bonus-token bookkeeping is updated, and if rejection-sampling metrics
        are available they are attached to the first output and the stage
        times are logged.
        """
        batch_size, num_steps = accepted_token_ids.shape
        # Step-major view: row s holds every sequence's token at step s.
        token_ids_by_step = accepted_token_ids.transpose(0, 1)
        max_logprobs = self.scorer_worker.model_config.max_logprobs

        if self._disable_logprobs:
            # Logprobs are skipped, so the logprob tensors are not serialized
            # from the GPU; placeholder lists are used instead.
            (ranks_by_step, token_logprobs_by_step, topk_logprobs_by_step,
             topk_ids_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps, max_logprobs)
        else:
            (ranks_by_step, token_logprobs_by_step, topk_logprobs_by_step,
             topk_ids_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs.transpose(0, 1), token_ids_by_step,
                 max_logprobs)

        # Sequence ids and per-sequence num_logprobs sampling parameter.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize the token ids to a nested Python list once.
        token_ids_by_step = token_ids_by_step.tolist()

        sampler_output_list: List[SamplerOutput] = []
        for step_index, step_token_ids in enumerate(token_ids_by_step):
            if all(token_id == -1 for token_id in step_token_ids):
                break

            ranks = ranks_by_step[step_index]
            token_logprobs = token_logprobs_by_step[step_index]
            topk_ids = topk_ids_by_step[step_index]
            topk_logprobs = topk_logprobs_by_step[step_index]

            step_outputs: List[CompletionSequenceGroupOutput] = []
            for seq_index in range(batch_size):
                # Each sequence may request a different number of logprobs.
                num_logprobs = num_logprobs_per_seq[seq_index]
                step_outputs.append(
                    create_sequence_group_output(
                        token_id=step_token_ids[seq_index],
                        token_id_logprob_rank=ranks[seq_index],
                        token_id_logprob=token_logprobs[seq_index],
                        seq_id=seq_ids[seq_index],
                        topk_token_ids=topk_ids[seq_index][:num_logprobs],
                        topk_logprobs=topk_logprobs[seq_index][:num_logprobs],
                    ))
            sampler_output_list.append(SamplerOutput(outputs=step_outputs))

        # Record which sequences received a bonus token in this step.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                token_ids_by_step)

        rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(k)
        if rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[0].spec_decode_worker_metrics = (
                rejsample_metrics)
            # The rejection sampler emits metrics periodically, so stage
            # times are logged at the same cadence.
            self._maybe_log_stage_times(*stage_times)

        return sampler_output_list

    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """Serialize target-model logprob information into Python lists.

        Returns, in order:
            - ranks of the accepted tokens, shaped (num_steps, batch_size)
            - logprobs of the accepted tokens, shaped (num_steps, batch_size)
            - logprobs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - token ids of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Target model logprobs,
                shaped (num_steps, batch_size, vocab_size).
            accepted_token_ids_by_step (torch.Tensor): Accepted token ids,
                shaped (num_steps, batch_size).
            num_top_k (int): Number of top-k logprobs to return.
        """
        # Rank and logprob of each accepted token.
        ranks_tensor, token_logprobs_tensor = get_sampled_token_logprobs(
            logprob_tensor=target_logprobs_by_step,
            sampled_token_ids=accepted_token_ids_by_step,
        )
        # Top-k logprobs; these may or may not include the accepted token.
        topk_values, topk_ids = target_logprobs_by_step.topk(k=num_top_k,
                                                             dim=-1)
        return (ranks_tensor.tolist(), token_logprobs_tensor.tolist(),
                topk_values.tolist(), topk_ids.tolist())

    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, as reported by the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of the scorer worker; tensors created here are placed on it."""
        scorer = self.scorer_worker
        return scorer.device

    @property
    def _driver_rank(self) -> int:
        return 0

1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
    def get_cache_block_size_bytes(self):
        """Unsupported for SpecDecodeWorker; always raises NotImplementedError.

        This hook is only used to compose workers inside a SpecDecodeWorker.
        Nesting a SpecDecodeWorker inside another SpecDecodeWorker is left
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError()



def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Split a GPU block budget between the target and draft models.

    ``total_num_gpu_blocks`` is the number of GPU blocks that could be
    allocated to the target (scorer) model alone. The target and draft models
    allocate the same number of blocks. However, a block's size in bytes
    usually differs between them, since it depends on the number of KV
    layers, the number of heads and the hidden dimension size.

    The result is the largest block count ``n`` such that ``n`` blocks of
    each model together use no more memory than ``total_num_gpu_blocks``
    target blocks::

        n = floor(total * scorer_bytes / (proposer_bytes + scorer_bytes))

    The computation uses exact integer arithmetic, so large values are not
    rounded through a float. Inputs are assumed to be non-negative.

    Raises:
        ZeroDivisionError: if both block sizes are zero.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)
    return int(total_num_gpu_blocks * scorer_cache_block_size_bytes //
               combined_block_size_bytes)


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional[HiddenStates]:
    """Shift prefill hidden states so they line up with the proposer's inputs.

    For the prefill step in the proposer, we run the model for N-1 tokens,
    because the Nth token will be processed in the first decode step. For
    those N-1 tokens, the input should be the 0:N-1 hidden states,
    concatenated with tokens 1:N (the scorer's output is the proposer's
    input). Rolling by one along dim 0 aligns the (n-1)th hidden state with
    the nth token.

    Returns None when ``prefill_hidden_states`` is None.
    """
    if prefill_hidden_states is None:
        return None
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))