spec_decode_worker.py 53.3 KB
Newer Older
1
import copy
2
from collections import defaultdict
3
from functools import cached_property
4
from typing import Any, Dict, List, Optional, Set, Tuple, Type
5
6
7

import torch

8
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
9
from vllm.distributed.communication_op import broadcast_tensor_dict
10
from vllm.logger import init_logger
11
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
from vllm.model_executor.layers.spec_decode_base_sampler import (
14
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
15
16
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
17
18
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
19
                           HiddenStates, SequenceGroupMetadata,
20
                           get_all_seq_ids_and_request_ids)
21
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
22
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
23
24
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
25
from vllm.spec_decode.medusa_worker import MedusaWorker
26
from vllm.spec_decode.metrics import AsyncMetricsCollector
27
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
28
from vllm.spec_decode.mqa_scorer import MQAScorer
29
from vllm.spec_decode.multi_step_worker import MultiStepWorker
30
from vllm.spec_decode.ngram_worker import NGramWorker
31
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
32
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
33
from vllm.spec_decode.target_model_runner import TargetModelRunner
34
35
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
36
                                   get_all_num_logprobs,
37
                                   get_sampled_token_logprobs, nvtx_range,
38
                                   split_batch_by_proposal_len)
39
from vllm.worker.worker import Worker
40
41
42
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)
43
44


45
46
47
48
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.

    The target (scorer) worker is built from ``kwargs`` with
    ``TargetModelRunner`` as its model runner. The draft (proposer) worker
    receives a copy of ``kwargs`` taken *before* that override, whose
    ``vllm_config`` is replaced by a deep copy carrying the draft model,
    quantization and parallel configs.

    Args:
        *args: Positional arguments forwarded to the target ``Worker``.
        **kwargs: Keyword arguments forwarded to the target ``Worker``. Must
            contain a ``vllm_config`` whose ``speculative_config`` is set.

    Returns:
        A configured ``SpecDecodeWorker``.
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    # Check before dereferencing; otherwise a missing config surfaces as an
    # AttributeError on None instead of a clear assertion.
    assert vllm_config is not None, (
        "create_spec_worker requires a 'vllm_config' keyword argument")
    speculative_config: SpeculativeConfig = vllm_config.speculative_config
    assert speculative_config is not None

    # Snapshot kwargs before the runner override below so the draft worker
    # does not inherit TargetModelRunner.
    draft_worker_kwargs = kwargs.copy()

    kwargs["model_runner_cls"] = TargetModelRunner
    target_worker = Worker(*args, **kwargs)
    # Set the disable_logprobs variable in the TargetModelRunner instance
    # as per its value specified in the SpeculativeConfig.
    target_worker.model_runner.disable_logprobs = (
        speculative_config.disable_logprobs)

    # Deep copy so the draft-specific overrides below do not modify the
    # config object the target worker was built with.
    draft_worker_config = copy.deepcopy(vllm_config)
    draft_worker_config.model_config = speculative_config.draft_model_config
    draft_worker_config.quant_config = VllmConfig._get_quantization_config(
        draft_worker_config.model_config,
        vllm_config.load_config,
    )
    draft_worker_config.parallel_config = speculative_config.draft_parallel_config  # noqa
    # TODO allow draft-model specific load config.

    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        vllm_config=draft_worker_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
        draft_token_acceptance_method=speculative_config.
        draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=speculative_config.
        typical_acceptance_sampler_posterior_threshold,
        typical_acceptance_sampler_posterior_alpha=speculative_config.
        typical_acceptance_sampler_posterior_alpha,
        disable_logprobs=speculative_config.disable_logprobs,
        disable_log_stats=speculative_config.disable_log_stats,
    )

    return spec_decode_worker


97
98
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# if the feature combo becomes valid.
99
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
100
101
102
103
104
105
106
107
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Proposals come from a draft model (including EAGLE, MLPSpeculator and
        Medusa models) or from n-gram prompt lookup; contributions for more
        forms are welcome!
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The MQA scorer (MQAScorer) is only used when it is not disabled and the
        setup supports it (FLASH_ATTN backend, target model in eager mode, and
        a draft max_model_len no smaller than the target's). Otherwise scoring
        falls back to batch expansion (BatchExpansionTop1Scorer), which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

125
    @classmethod
    def create_worker(
        cls,
        scorer_worker: Worker,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker from a target worker and draft kwargs.

        Chooses the proposer, the acceptance sampler, and whether the MQA
        scorer can be used.
        - Proposer: n-gram lookup, MLPSpeculator, Medusa, or a multi-step
          draft model, wrapped by SmallerTpProposerWorker when the draft
          TP differs from the target TP.

        Args:
            scorer_worker: The target-model worker used for scoring.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain ``vllm_config``, ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``. The two n-gram keys are popped,
                i.e. this dict is mutated in place.
            disable_mqa_scorer: Request batch-expansion scoring instead of the
                MQA scorer. This method may also force it to True.
            disable_by_batch_size: Running-queue size at which speculation is
                disabled; None means no limit.
            draft_token_acceptance_method: Either "rejection_sampler" or
                "typical_acceptance_sampler".
            typical_acceptance_sampler_posterior_threshold: Forwarded to
                TypicalAcceptanceSampler.
            typical_acceptance_sampler_posterior_alpha: Forwarded to
                TypicalAcceptanceSampler.
            disable_logprobs: Do not output token log probabilities.
            disable_log_stats: Disable periodic printing of speculative stage
                times.

        Returns:
            The configured worker.

        Raises:
            NotImplementedError: If an EAGLE draft model is used with a draft
                tensor-parallel size greater than 1.
            ValueError: If ``draft_token_acceptance_method`` is unknown.
        """
        allow_zero_draft_token_step = True
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
            'vllm_config'].parallel_config

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_model_config.hf_config.model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_config.hf_config.model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    draft_worker_kwargs[
                        "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_config.hf_config.model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")

                    # Zero-draft-token steps are not supported with draft
                    # TP > 1 (see TODO #5814 in __init__).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Previously the sampler silently stayed None here and the worker
            # failed later with an AttributeError on `probs_dtype`.
            raise ValueError(
                "Unknown draft_token_acceptance_method: "
                f"{draft_token_acceptance_method!r}. Expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            if scorer_worker.model_runner.attn_backend.get_name(
            ) != "FLASH_ATTN":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if draft_model_config and \
                draft_model_config.max_model_len < \
                    scorer_worker.model_config.max_model_len:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return cls(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
222

223
224
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """Wire a proposer and a scorer together into a SpecDecodeWorker.

        Args:
            proposer_worker: Produces speculative (draft) tokens for sequences.
            scorer_worker: Computes probabilities of the speculative tokens
                according to the base model; typically a vanilla vLLM Worker.
            spec_decode_sampler: Torch module that performs acceptance
                sampling of the draft tokens during the verification step,
                e.g. a RejectionSampler or a TypicalAcceptanceSampler. Its
                ``probs_dtype`` and ``token_id_dtype`` are adopted by this
                worker.
            disable_mqa_scorer: When True, score proposals with
                BatchExpansionTop1Scorer instead of the MQA scorer.
            disable_logprobs: When True, token log probabilities are output by
                neither the draft nor the target worker; when False, both
                output them.
            disable_log_stats: When True, periodic printing of speculative
                stage times is turned off.
            metrics_collector: Optional metrics helper (handy in tests). A new
                AsyncMetricsCollector is created when omitted.
            disable_by_batch_size: Running-queue size at or above which
                speculative decoding is disabled for new incoming requests.
                A falsy value (None or 0) means no limit.
            allow_zero_draft_token_step: Whether a step in which the draft
                model proposes no tokens is allowed; should be disallowed when
                the draft model's TP is larger than 1 (TODO: #5814).
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker

        # Generators exposed by the scorer's model runner, or None when the
        # scorer worker has no (truthy) model_runner attribute.
        runner = getattr(self.scorer_worker, "model_runner", None)
        if runner:
            self.generators = runner.get_generators()
        else:
            self.generators = None

        # A falsy threshold means "never auto-disable speculation".
        self.disable_by_batch_size = (disable_by_batch_size
                                      if disable_by_batch_size else
                                      float("inf"))
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        if metrics_collector is None:
            metrics_collector = AsyncMetricsCollector(self.spec_decode_sampler)
        self._metrics = metrics_collector

        # Sequence IDs that received a bonus token ID in their last forward
        # pass. Only needed when the proposer generates tokens with a KV
        # cache, as MultiStepWorker does.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Active request IDs mapped to the sequence IDs that belong to them.
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        # Dtypes are dictated by the acceptance sampler.
        sampler = self.spec_decode_sampler
        self.probs_dtype = sampler.probs_dtype
        self.token_id_dtype = sampler.token_id_dtype

        # The scorer itself is created lazily in init_device().
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Target-model hidden states handed to the proposer on the next step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats
296

297
    def init_device(self) -> None:
        """Initialize devices and models of both workers, then build the scorer.

        Device initialization runs scorer-first, in case the proposer model
        has a smaller TP degree than the target worker. Models are then loaded
        in the same order. Finally the scorer implementation is chosen from
        ``self.disable_mqa_scorer`` and the samplers are configured.
        """
        workers = (self.scorer_worker, self.proposer_worker)
        for worker in workers:
            worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        for worker in workers:
            worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.spec_decode_sampler.init_gpu_tensors(self.rank)

        use_batch_expansion = self.disable_mqa_scorer
        scorer_cls: Type[SpeculativeScorer] = (
            BatchExpansionTop1Scorer if use_batch_expansion else MQAScorer)
        logger.info(
            "[Speculative Decoding] Use batch expansion for scoring proposals."
            if use_batch_expansion else
            "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

328
329
330
    def load_model(self, *args, **kwargs):
        """Intentionally a no-op: both models are loaded in ``init_device``."""

331
332
333
    def _configure_model_sampler_for_spec_decode(self):
        """Make the samplers emit GPU tensors for speculative decoding.

        Keeping the sampled data on device avoids transferring it to CPU and
        serializing it, which significantly reduces sampling overhead during
        verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        target_sampler = self.scorer_worker.model_runner.model.sampler
        target_sampler.include_gpu_probs_tensor = True
        target_sampler.should_modify_greedy_probs_inplace = True

        # The proposer worker is configured through its own setter methods.
        proposer = self.proposer_worker
        proposer.set_include_gpu_probs_tensor()
        proposer.set_should_modify_greedy_probs_inplace()
356

357
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        The scorer model (typically the larger of the two) is profiled. The
        GPU memory its cache would use is then divided evenly between the
        scorer and proposer KV caches, so both caches end up with the same
        number of blocks. The CPU block count is passed through unchanged.

        Returns:
            A ``(num_gpu_blocks, num_cpu_blocks)`` tuple.
        """
        profiled_gpu_blocks, profiled_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            self.scorer_worker.get_cache_block_size_bytes(),
            self.proposer_worker.get_cache_block_size_bytes(),
            profiled_gpu_blocks,
        )
        return shared_gpu_blocks, profiled_cpu_blocks

378
379
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engines of both workers with the same sizes.

        The scorer worker is initialized first, then the proposer worker.
        """
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)
386
387
388

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        On non-driver ranks this runs one iteration of the non-driver loop and
        returns an empty list. On the driver rank it first broadcasts the step
        control flags to the other ranks. It then runs either a plain
        (non-speculative) step or a full speculative decoding step.

        Args:
            execute_model_req: The batch to execute. ``None`` means there is
                nothing more to process for now. The driver then broadcasts
                an empty dict so non-driver ranks leave their execution loop.

        Returns:
            The sampler outputs of the step. ``[]`` is returned on non-driver
            ranks and when ``execute_model_req`` is None.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        # Validate before the list is iterated below; iterating None would
        # otherwise raise an opaque TypeError. Checking before the broadcast
        # also keeps non-driver ranks from receiving flags for a step that the
        # driver then abandons.
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Single pass over the batch to classify it.
        all_prompt = True
        atleast_one_prompt = False
        all_zero_spec_tokens = True
        for sgm in execute_model_req.seq_group_metadata_list:
            all_prompt = all_prompt and sgm.is_prompt
            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
            all_zero_spec_tokens = all_zero_spec_tokens and (
                sgm.num_speculative_tokens == 0)

        if all_prompt and execute_model_req.seq_group_metadata_list:
            assert num_lookahead_slots == 0, (
                "Prompt only runs should have num_lookahead_slots equal to 0. "
                "This should never happen, please file a bug at "
                "https://github.com/vllm-project/vllm/issues")

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        # We expect `num_speculative_tokens` to be None for prefills.
        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.

        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
            # When both chunked prefill and speculative decoding are enabled
            # it is possible that the same batch contains both prefill
            # and decodes. If that happens in the scorer we run the batch
            # as one single forward pass. However, in the proposer we
            # run them as 2 different batches - one for prefill and
            # the other for decodes. The variable indicates to the non-driver
            # worker that there are prefills as part of the speculative batch
            # and hence it needs to run an extra prefill forward pass.
            run_spec_proposer_for_prefill=atleast_one_prompt,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Drive the non-driver side of speculative decoding.

        Each iteration runs one step via ``_run_non_driver_rank``. The loop
        stops as soon as that call returns a falsy value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

486
487
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True once the running queue reaches the disable threshold.

        With a batch this large, speculative decoding is switched off to stop
        trading off throughput for latency.
        """
        queue_size = execute_model_req.running_queue_size
        threshold = self.disable_by_batch_size
        return queue_size >= threshold
492
493
494
495
496
497
498
499
500
501
502
503
504
505

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero out ``num_speculative_tokens`` for every group when requested.

        This has no effect unless ``disable_all_speculation`` is True. Once a
        request's ``num_speculative_tokens`` is set to 0, spec decode stays
        disabled for that request from then on.
        """
        if disable_all_speculation:
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for metadata in seq_group_metadata_list:
                metadata.num_speculative_tokens = 0
506

507
508
    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """
        Creates and returns a `SamplerOutput` with only the token IDs being
        serialized to CPU and populated in `CompletionSequenceGroupOutput`.
        All other parameters in `CompletionSequenceGroupOutput` related to log
        probabilities are skipped. Sampled tokens get placeholder logprob
        values, and requested prompt logprobs are filled with placeholder
        entries.

        Every sequence group in the request gets exactly one output entry.
        Groups that do not sample in this step (non-terminal prefill chunks)
        get an empty entry, so the outputs stay aligned with the request.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
                was executed.
            sampler_output (SamplerOutput): The output from the sampler with
                only GPU tensors populated.

        Returns:
            List[SamplerOutput]: A single-element list. Its `SamplerOutput`
                holds one `CompletionSequenceGroupOutput` per sequence group,
                with only token IDs populated.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # One flag per sequence group (sampling or not): does the group need
        # prompt logprobs in this step?
        seq_output_prompt_logprobs = [
            sgm.is_prompt and sgm.sampling_params.prompt_logprobs is not None
            and sgm.sampling_params.prompt_logprobs > 0
            for sgm in seq_group_metadata_list
        ]

        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
        sampled_token_ids = sampler_output.sampled_token_ids
        if any(seq_output_prompt_logprobs):
            # subtracting is faster than testing for equality
            valid_rows = torch.where(sampled_token_ids -
                                     VLLM_INVALID_TOKEN_ID)[0]
            sampled_token_ids = sampled_token_ids[valid_rows]
        sampled_token_ids_list = sampled_token_ids.tolist()

        # (seq_id, seq_data) for each sequence of each sampling group.
        seq_data_entries = [
            (seq_id, seq_data) for sg in seq_group_metadata_list
            for seq_id, seq_data in sg.seq_data.items()
            if sg.do_sample  # ignore empty token sequences
        ]

        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        # Two cursors are needed:
        # - `group_index` walks every sequence group and indexes
        #   `seq_output_prompt_logprobs`.
        # - `output_index` counts only sampling groups and indexes
        #   `seq_data_entries` and `sampled_token_ids_list`.
        # They diverge as soon as a non-terminal prefill chunk precedes a
        # sampling group. Indexing the per-group flags with `output_index`
        # would read another group's prompt-logprobs setting.
        output_index = 0
        for group_index, seq_group_meta in enumerate(seq_group_metadata_list):
            # Since we can get chunks here, we dont always have a sampled token
            # (only on last chunk) but we still have to provide an output so
            # non-terminal prefill chunks stay aligned with their own entry.
            if not seq_group_meta.do_sample:
                completion_seq_group_output_list.append(
                    CompletionSequenceGroupOutput(samples=[],
                                                  prompt_logprobs=None))
                continue

            # Sequence with output.
            seq_id, seq_data = seq_data_entries[output_index]
            if seq_output_prompt_logprobs[group_index]:
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    )
                    # no prompt logprobs for the first token
                    for p_token_id in prompt_token_ids[1:]
                ]
            else:
                prompt_logprobs = None
            completion_seq_group_output_list.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))
            output_index += 1

        return [SamplerOutput(outputs=completion_seq_group_output_list)]
589

590
    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run one generation step with no speculation.

        The batch is executed by the scorer and, unless ``skip_proposer`` is
        True, by the proposer as well, so the two KV caches stay consistent.
        When the proposer is skipped, its KV cache is not updated for these
        requests. Those requests therefore cannot enable spec decode for the
        rest of their decoding.

        Target-model hidden states, when present, are recorded in
        ``self.previous_hidden_states`` for the proposer's next step.
        """
        scorer_outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(scorer_outputs) == 1
        sampler_output = scorer_outputs[0]

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            metadata_list = execute_model_req.seq_group_metadata_list
            # remove hidden_states for prompt tokens
            # TODO Enable `return_hidden_states`: prefill chunks hidden states
            # are pruned by the logits processor. Also, they should be arranged
            # back into full-prefill latent. Address it to enable MLPSpeculator.
            if any(sgm.is_prompt for sgm in metadata_list):
                kept_rows = torch.where(sampler_output.sampled_token_ids -
                                        VLLM_INVALID_TOKEN_ID)[0]
                hidden_states = hidden_states[kept_rows]
            if self.previous_hidden_states is not None:
                self.previous_hidden_states.update(hidden_states,
                                                   metadata_list)
            else:
                self.previous_hidden_states = HiddenStates(
                    hidden_states, metadata_list)

        if not skip_proposer:
            # Preparing the prefill hidden states here keeps the worker free
            # of separate spec-decode / non-spec-decode code paths, so
            # execute_model needs no further changes.
            execute_model_req.previous_hidden_states = (
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states))
            self.proposer_worker.execute_model(execute_model_req)

        if self._disable_logprobs:
            outputs = self._serialize_sampler_output_no_logprobs(
                execute_model_req=execute_model_req,
                sampler_output=sampler_output)
        else:
            outputs = [sampler_output]

        # Drop the device tensors from the sampler output to reduce
        # communication overhead when the engine runs in a different process
        # than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return outputs
644

645
    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers.

        This is used for both speculation cases (num_lookahead_slots>0) and
        non-speculation cases (e.g. prefill). The control flags
        (``num_lookahead_slots``, ``no_spec``, ``disable_all_speculation``,
        ``run_spec_proposer_for_prefill``) are received from the driver rank
        via ``broadcast_tensor_dict``.

        Returns:
            True if there are remaining sequences to process (the broadcast
            payload was non-empty), False otherwise.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if data["no_spec"]:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the
            # proposer model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        if not data["no_spec"]:
            self.scorer_worker.execute_model()
            # Presumably mirrors the driver's extra proposer forward that
            # syncs the proposer KV cache for prefills in a mixed batch.
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True
679

680
681
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        When `enable_chunked_prefill` is set, scorer will batch decodes and
        prefills, while proposer will sync its KV-cache by running an extra
        forward on prefills.

        Args:
            execute_model_req: The batch to execute. Its
                ``num_lookahead_slots`` must equal ``num_lookahead_slots``.
            num_lookahead_slots: Number of speculative tokens (k) per
                sequence. Also used to average the proposal time per token.

        Returns:
            A list of SamplerOutput, each containing a single token per
            sequence.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        # The scorer must not see the proposer's hidden-state input.
        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
        # discard decodes that have already been processed by proposer.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if non_spec_indices:
            all_hidden_states = proposal_scores.hidden_states
            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = (
                    prepare_prefill_hidden_states(prefill_hidden_states))
            # Sync proposer KV cache for prefills.
            # NOTE(review): only `non_spec_indices` is filtered to prompts
            # above. `non_spec_seqs` still holds every non-speculative
            # sequence, so any non-spec decodes are forwarded here as well.
            # Confirm this is intended.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:
            accepted_token_ids, target_logprobs = self._verify_tokens(
                execute_model_req.seq_group_metadata_list, proposal_scores,
                proposals, execute_model_req.num_lookahead_slots)

        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)
756
757
758
759
760
761
762
763

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Args:
            seq_group_metadata_list: The batch, in its original order.
            proposal_scores: Scorer (target model) outputs for the batch.
            proposals: Proposer (draft model) outputs for the batch.
            max_proposal_len: Proposal length k. Non-speculative rows are
                padded with -1 to width ``max_proposal_len + 1``.

        Returns:
            A tuple of Tensors: the accepted token ids (rows in the order of
            ``seq_group_metadata_list``) and the logprobs according to the
            scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs[spec_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            # The sampler only receives the speculative rows selected by
            # `spec_indices`, so each generator must be keyed by its row
            # position within that sub-batch, not by the sequence's index in
            # the full `seq_group_metadata_list`. Otherwise mixed spec/non-spec
            # batches hand generators to the wrong rows.
            spec_seq_groups = [
                seq_group_metadata_list[idx] for idx in spec_indices
            ]
            sampler_extra_kwargs["seeded_seqs"] = {
                row: self.generators[sgm.request_id]
                for row, sgm in enumerate(spec_seq_groups)
                if sgm.sampling_params.seed is not None
            }

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]

            # Per row: (number of non -1 tokens) - 1, i.e. the position of the
            # last accepted token (assumes -1 padding only trails the
            # accepted tokens).
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list,
                second_last_token_hidden_states)
        return accepted_token_ids, logprobs
840
841
842
843
844

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs. One SamplerOutput is produced per step,
        stopping at the first step in which every sequence's token is -1.

        Args:
            seq_group_metadata_list: The batch, in the row order of
                ``accepted_token_ids``.
            accepted_token_ids: Accepted token ids, -1 padded.
            target_logprobs: Scorer logprobs. Ignored when logprobs are
                disabled.
            k: Number of lookahead slots, forwarded to the rejection-sampling
                metrics collector.
            stage_times: (average proposal time per token, scoring time,
                verification time) in ms, logged together with metrics.

        Returns:
            One SamplerOutput per emitted step.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps,
                 self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
        sampler_output_list: List[SamplerOutput] = []
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            # The list is empty when every row of the first step is -1 (the
            # loop above breaks immediately). Indexing it would then raise
            # IndexError, so only attach metrics when there is an output.
            if sampler_output_list:
                sampler_output_list[
                    0].spec_decode_worker_metrics = maybe_rejsample_metrics

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)

        return sampler_output_list

934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """Convert scorer logprob tensors into nested Python lists.

        Produces, in order:
            - the ranks of the accepted tokens, shaped (num_steps, batch_size);
            - the logprobs of the accepted tokens, shaped
              (num_steps, batch_size);
            - the top-k logprobs, shaped (num_steps, batch_size, num_top_k);
            - the top-k token ids, shaped (num_steps, batch_size, num_top_k).

        Args:
            target_logprobs_by_step (torch.Tensor): Target-model logprobs,
                shaped (num_steps, batch_size, vocab_size).
            accepted_token_ids_by_step (torch.Tensor): Accepted token ids,
                shaped (num_steps, batch_size).
            num_top_k (int): Number of top-k logprobs to keep.

        Returns:
            The four lists described above.
        """
        # Rank and logprob of each accepted token.
        ranks_tensor, accepted_logprobs_tensor = get_sampled_token_logprobs(
            logprob_tensor=target_logprobs_by_step,
            sampled_token_ids=accepted_token_ids_by_step,
        )
        # Top-k over the vocab; may or may not contain the accepted token.
        topk_values_tensor, topk_ids_tensor = target_logprobs_by_step.topk(
            k=num_top_k, dim=-1)
        # Move everything to CPU Python lists.
        return (ranks_tensor.tolist(), accepted_logprobs_tensor.tolist(),
                topk_values_tensor.tolist(), topk_ids_tensor.tolist())

1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, delegated to the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of this worker, delegated to the scorer worker."""
        scorer = self.scorer_worker
        return scorer.device

1099
1100
1101
1102
    @property
    def _driver_rank(self) -> int:
        return 0

1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        Not supported: this method exists only so workers can be composed
        inside a SpecDecodeWorker, and nesting a SpecDecodeWorker inside
        another SpecDecodeWorker is left undefined for now (it could be
        implemented in the future; see https://arxiv.org/abs/2308.04623).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

1113
1114
1115
1116
1117
1118
1119
1120
    def start_profile(self):
        """Start profiling on the scorer worker if it is a ``Worker``.

        Does nothing for any other scorer worker type.
        """
        scorer = self.scorer_worker
        if not isinstance(scorer, Worker):
            return
        scorer.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker if it is a ``Worker``.

        Does nothing for any other scorer worker type.
        """
        scorer = self.scorer_worker
        if not isinstance(scorer, Worker):
            return
        scorer.stop_profile()

1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Size of one scorer (target) block.
        proposer_cache_block_size_bytes: Size of one proposer (draft) block.
        total_num_gpu_blocks: Blocks that fit if only the scorer allocated.

    Returns:
        floor(total_num_gpu_blocks * scorer / (scorer + proposer)).

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    # Exact integer floor division. A float quotient is rounded to the nearest
    # double, which for large operands can land on the next integer and hand
    # out one block more than the memory budget allows.
    new_num_gpu_blocks = (
        total_num_gpu_blocks * scorer_cache_block_size_bytes //
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return int(new_num_gpu_blocks)
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional[HiddenStates]:
    """Shift scorer prefill hidden states so they line up with proposer input.

    Args:
        prefill_hidden_states: Hidden states from the scorer's prefill, or
            None.

    Returns:
        A HiddenStates wrapping the tensor rolled by one along dim 0, or None
        when ``prefill_hidden_states`` is None.
    """
    if prefill_hidden_states is None:
        return None
    # For prefill step in proposer, we run the model for N-1 tokens
    # because Nth token will be processed in the first decode step. For
    # N-1 tokens, the input should be 0:N-1 hidden states which should
    # be concatenated with 1:N token (since output of scorer has to be
    # the input for proposer). Therefore, we shift the hidden states to
    # align n-1th hidden state with nth token. `roll(shifts=1, dims=0)` moves
    # row i to row i+1 and wraps the last row around to row 0.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))