spec_decode_worker.py 54.3 KB
Newer Older
1
import copy
2
from collections import defaultdict
3
from functools import cached_property
4
from typing import Any, Dict, List, Optional, Set, Tuple, Type
5
6
7

import torch

8
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
9
from vllm.distributed.communication_op import broadcast_tensor_dict
10
from vllm.logger import init_logger
11
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
from vllm.model_executor.layers.spec_decode_base_sampler import (
14
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
15
16
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
17
from vllm.platforms import current_platform
18
19
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
20
                           HiddenStates, SequenceGroupMetadata,
21
                           get_all_seq_ids_and_request_ids)
22
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
23
24
25
26

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

27
28
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
29
from vllm.spec_decode.medusa_worker import MedusaWorker
30
from vllm.spec_decode.metrics import AsyncMetricsCollector
31
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
32
from vllm.spec_decode.mqa_scorer import MQAScorer
33
from vllm.spec_decode.multi_step_worker import MultiStepWorker
34
from vllm.spec_decode.ngram_worker import NGramWorker
35
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
36
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
37
from vllm.spec_decode.target_model_runner import TargetModelRunner
38
39
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
40
                                   get_all_num_logprobs,
41
                                   get_sampled_token_logprobs, nvtx_range,
42
                                   split_batch_by_proposal_len)
43
44
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
45
46

logger = init_logger(__name__)
47
48


49
50
51
52
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from the speculative config.

    This is the entrypoint for Executors that go through WorkerWrapper. The
    target (scorer) worker is constructed from ``kwargs`` with
    ``TargetModelRunner`` as its model runner. The draft (proposer) worker
    receives a separate copy of ``kwargs`` whose ``vllm_config`` is rebuilt
    from ``speculative_config.draft_model_config`` and
    ``speculative_config.draft_parallel_config``.

    Raises:
        NotImplementedError: If pipeline parallelism is enabled.
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    speculative_config: SpeculativeConfig = vllm_config.speculative_config
    assert speculative_config is not None

    if vllm_config.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError("Speculative decoding is currently "
                                  "incompatible with pipeline parallelism")

    # Snapshot the kwargs *before* injecting the target model runner so that
    # the draft worker does not inherit TargetModelRunner.
    draft_kwargs = kwargs.copy()
    kwargs["model_runner_cls"] = TargetModelRunner

    # ---- Target (scorer) worker ------------------------------------------
    scorer_cfg = copy.deepcopy(vllm_config)
    scorer_parallel = scorer_cfg.parallel_config
    scorer_parallel.worker_cls = scorer_parallel.sd_worker_cls
    scorer_worker_cls = resolve_obj_by_qualname(scorer_parallel.worker_cls)
    scorer_worker = scorer_worker_cls(*args, **kwargs)

    # Mirror the SpeculativeConfig's disable_logprobs onto the target runner.
    scorer_worker.model_runner.disable_logprobs = (
        speculative_config.disable_logprobs)

    # ---- Draft (proposer) worker config ----------------------------------
    draft_cfg = copy.deepcopy(vllm_config)
    draft_cfg.model_config = speculative_config.draft_model_config
    draft_cfg.quant_config = VllmConfig._get_quantization_config(
        draft_cfg.model_config,
        vllm_config.load_config,
    )
    speculative_config.draft_parallel_config.worker_cls = (
        draft_cfg.parallel_config.sd_worker_cls)
    draft_cfg.parallel_config = speculative_config.draft_parallel_config
    # TODO allow draft-model specific load config.

    # Override draft-model specific worker args.
    draft_kwargs.update(
        vllm_config=draft_cfg,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
    )

    sc = speculative_config
    return SpecDecodeWorker.create_worker(
        scorer_worker=scorer_worker,
        draft_worker_kwargs=draft_kwargs,
        disable_mqa_scorer=sc.speculative_disable_mqa_scorer,
        disable_by_batch_size=sc.speculative_disable_by_batch_size,
        draft_token_acceptance_method=sc.draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=(
            sc.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            sc.typical_acceptance_sampler_posterior_alpha),
        disable_logprobs=sc.disable_logprobs,
        disable_log_stats=sc.disable_log_stats,
    )


112
# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combo becomes valid.
114
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
115
116
117
118
119
120
121
122
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Proposal methods are limited to a draft model, n-gram prompt lookup, an
        MLPSpeculator, or Medusa heads (contributions for more forms are
        welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The MQA scorer is only used with the FLASH_ATTN backend in eager mode,
        and only when the draft model's max_model_len is not smaller than the
        target model's; otherwise scoring falls back to batch expansion, which
        is suboptimal especially as the batch size, proposal length, and
        sequence lengths grow.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

140
    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Construct a SpecDecodeWorker around an already-built scorer worker.

        The proposer is chosen as follows: an NGramWorker when
        ``ngram_prompt_lookup_max > 0``; otherwise an MLPSpeculatorWorker or
        MedusaWorker depending on the draft model's ``hf_config.model_type``,
        and a MultiStepWorker for any other draft model. Non-ngram proposers
        are passed through ``SmallerTpProposerWorker.maybe_wrap_worker``.

        Args:
            scorer_worker: The target (scorer) worker.
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain ``vllm_config``, ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``. NOTE: this dict is mutated in
                place: the two ngram keys are popped, and ``device_type`` or
                ``model_runner_cls`` may be added.
            disable_mqa_scorer: Request batch-expansion scoring instead of
                the MQA scorer. Even when False, the MQA scorer is disabled if
                the target attention backend is not FLASH_ATTN, the draft
                model's max_model_len is smaller than the target's, or the
                target model is not running in eager mode.
            disable_by_batch_size: Forwarded to ``SpecDecodeWorker``.
            draft_token_acceptance_method: Either ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Posterior
                threshold for the TypicalAcceptanceSampler.
            typical_acceptance_sampler_posterior_alpha: Posterior alpha for
                the TypicalAcceptanceSampler.
            disable_logprobs: Forwarded to ``SpecDecodeWorker``.
            disable_log_stats: Forwarded to ``SpecDecodeWorker``.

        Returns:
            The configured SpecDecodeWorker.

        Raises:
            ValueError: If ``draft_token_acceptance_method`` is not supported.
            NotImplementedError: If an EAGLE draft model is used with a draft
                tensor-parallel size > 1.
        """
        # Validate up front: previously an unknown method left the sampler as
        # None, which only failed later inside __init__ with an AttributeError.
        supported_methods = ("rejection_sampler", "typical_acceptance_sampler")
        if draft_token_acceptance_method not in supported_methods:
            raise ValueError(
                "Unsupported draft_token_acceptance_method: "
                f"{draft_token_acceptance_method!r}; expected one of "
                f"{supported_methods}.")

        allow_zero_draft_token_step = True

        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
            'vllm_config'].parallel_config
        if ngram_prompt_lookup_max > 0:
            draft_worker_kwargs[
                "device_type"] = scorer_worker.device_config.device.type
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_model_config.hf_config.model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_config.hf_config.model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    if current_platform.is_cuda_alike():
                        draft_worker_kwargs[
                            "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_config.hf_config.model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")

                    # Zero-draft-token steps are not supported with draft
                    # TP > 1 (see the TODO on allow_zero_draft_token_step).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        else:  # "typical_acceptance_sampler", validated above.
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=\
                    typical_acceptance_sampler_posterior_threshold,
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            attn_backend_name = (
                scorer_worker.model_runner.attn_backend.get_name())
            if attn_backend_name != "FLASH_ATTN":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if draft_model_config and \
                draft_model_config.max_model_len < \
                    scorer_worker.model_config.max_model_len:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)

    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
                Its ``probs_dtype`` and ``token_id_dtype`` are copied onto
                this worker.
            disable_mqa_scorer: If set to True, disable the MQA scorer and use
                the BatchExpansionTop1Scorer instead.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
            disable_log_stats: If set to True, disable periodic printing of
                speculative stage times.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. Defaults to an AsyncMetricsCollector
                built around ``spec_decode_sampler``.
            disable_by_batch_size: Threshold on the running queue size. When
                the running queue size is greater than or equal to this value,
                speculation is disabled for the step, and speculation is
                turned off for every request in that batch. ``None`` (the
                default) never disables speculation this way.
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
        self.generators = scorer_runner.get_generators(
        ) if scorer_runner else None
        # Compare against None explicitly: `disable_by_batch_size or inf`
        # would silently turn an explicit threshold of 0 into "never".
        self.disable_by_batch_size = (float("inf")
                                      if disable_by_batch_size is None else
                                      disable_by_batch_size)
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization: the scorer is created in init_device().
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats

    def init_device(self) -> None:
        """Bring up both workers and finish the spec-decode wiring.

        Devices are initialized for the scorer and then the proposer, after
        which both models are loaded in the same order. The metrics collector
        and acceptance sampler then allocate their tensors for this rank. The
        speculative scorer (MQA or batch expansion) is chosen, and the model
        samplers are configured to keep probabilities on the GPU.
        """
        workers = (self.scorer_worker, self.proposer_worker)

        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        for worker in workers:
            worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        for worker in workers:
            worker.load_model()

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)

        if not self.disable_mqa_scorer:
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")
            scorer_cls: Type[SpeculativeScorer] = MQAScorer
        else:
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
            scorer_cls = BatchExpansionTop1Scorer

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        """No-op: the scorer and proposer models are loaded in init_device()."""

    def _configure_model_sampler_for_spec_decode(self):
        """Make the scorer and proposer samplers keep their outputs on GPU.

        The scorer model's sampler is told to include the GPU probability
        tensor and to modify greedy probabilities in place. The proposer
        worker is asked to do the same through its setter methods. Keeping
        these tensors on the device avoids a CPU transfer and serialization
        step during verification.

        NOTE(cade): This pierces abstraction boundaries. A cleaner design
        would leave the "move to CPU and serialize" decision to the last-mile
        worker that talks to the scheduler, so several iterations could run
        back to back without paying that cost. That needs a large change to
        vLLM, so the broken boundary is accepted for now.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        target_sampler = self.scorer_worker.model_runner.model.sampler
        target_sampler.include_gpu_probs_tensor = True
        target_sampler.should_modify_greedy_probs_inplace = True

        proposer = self.proposer_worker
        proposer.set_include_gpu_probs_tensor()
        proposer.set_should_modify_greedy_probs_inplace()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return ``(num_gpu_blocks, num_cpu_blocks)`` for both KV caches.

        The scorer (typically the larger model) is profiled to find how many
        blocks fit. Its GPU block count is then re-split with
        ``split_num_cache_blocks_evenly`` using the per-block byte sizes of
        the scorer and proposer, so both KV caches get the same number of
        blocks. The CPU block count is passed through unchanged.
        """
        profiled_gpu_blocks, cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            self.scorer_worker.get_cache_block_size_bytes(),
            self.proposer_worker.get_cache_block_size_bytes(),
            profiled_gpu_blocks,
        )
        return shared_gpu_blocks, cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Give the scorer, then the proposer, identical KV cache sizes."""
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Non-driver ranks run a single step of the non-driver loop and return
        an empty list. On the driver rank:

        * ``execute_model_req is None`` broadcasts an empty dict, which tells
          the non-driver workers to leave their execution loop, and returns
          an empty list;
        * otherwise the step metadata (lookahead slots, ``no_spec``, whether
          all speculation is disabled, whether the batch has prefills) is
          broadcast, and the batch is run either without speculation
          (``_run_no_spec``) or as a speculative decoding step.

        Args:
            execute_model_req: The batch to run on the driver; ``None`` is the
                stop signal described above.

        Returns:
            The sampler outputs for this step (empty on non-driver ranks and
            for the stop signal).
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. Use the same source rank as the step broadcast
            # below.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        self._track_finished_requests(execute_model_req)

        # Validate before the loop below iterates the list; otherwise a None
        # list fails there with a TypeError and this message never shows.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        assert seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        all_prompt = True
        atleast_one_prompt = False
        all_zero_spec_tokens = True
        for sgm in seq_group_metadata_list:
            all_prompt = all_prompt and sgm.is_prompt
            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
            all_zero_spec_tokens = all_zero_spec_tokens and (
                sgm.num_speculative_tokens == 0)

        if all_prompt and seq_group_metadata_list:
            assert num_lookahead_slots == 0, (
                "Prompt only runs should have num_lookahead_slots equal to 0. "
                "This should never happen, please file a bug at "
                "https://github.com/vllm-project/vllm/issues")

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        # We expect `num_speculative_tokens` to be None for prefills.
        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.

        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
            # When both chunked prefill and speculative decoding are enabled
            # it is possible that the same batch contains both prefill
            # and decodes. If that happens in the scorer we run the batch
            # as one single forward pass. However, in the proposer we
            # run them as 2 different batches - one for prefill and
            # the other for decodes. The variable indicates to the non-driver
            # worker that there are prefills as part of the speculative batch
            # and hence it needs to run an extra prefill forward pass.
            run_spec_proposer_for_prefill=atleast_one_prompt,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(disable_all_speculation,
                                               seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Drive speculative decoding on a non-driver (parallel) worker.

        Each iteration runs one step via ``_run_non_driver_rank``. The loop
        exits as soon as that call returns a falsy value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True once the running queue reaches disable_by_batch_size.

        With a large batch, speculation would keep trading throughput for
        latency, so it is switched off instead.
        """
        queue_size = execute_model_req.running_queue_size
        threshold = self.disable_by_batch_size
        return queue_size >= threshold

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` for every group when asked to.

        Does nothing unless ``disable_all_speculation`` is True.
        """
        if disable_all_speculation:
            for group_metadata in seq_group_metadata_list:
                # A request whose num_speculative_tokens is 0 never gets
                # spec decode again.
                # TODO(comaniac): We currently store spec decoding specific
                # state in the global data structure, but we should maintain
                # this state within spec decode worker.
                group_metadata.num_speculative_tokens = 0

    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """
        Creates and returns a `SamplerOutput` with only the token IDs being
        serialized to CPU and populated in `CompletionSequenceGroupOutput`.

        All other parameters in `CompletionSequenceGroupOutput` related to log
        probabilities are skipped. Sequence groups with ``do_sample=False``
        (non-terminal prefill chunks) get an empty output so positions stay
        aligned with the input list. Prompt groups that request prompt
        logprobs get placeholder entries (rank -1, logprob 0.0) for every
        prompt token except the first.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
            was executed.
            sampler_output (SamplerOutput): The output from the sampler with
            only GPU tensors populated.

        Returns:
            A single-element list holding a new `SamplerOutput` whose outputs
            are `CompletionSequenceGroupOutput` objects with only token IDs
            populated, one per sequence group.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # One flag per sequence group, *including* do_sample=False groups, so
        # it must be indexed by the group's position in the list.
        seq_output_prompt_logprobs = [
            sgm.is_prompt and sgm.sampling_params.prompt_logprobs is not None
            and sgm.sampling_params.prompt_logprobs > 0
            for sgm in seq_group_metadata_list
        ]

        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
        sampled_token_ids = sampler_output.sampled_token_ids
        if any(seq_output_prompt_logprobs):
            # subtracting is faster than testing for equality
            valid_rows = torch.where(sampled_token_ids -
                                     VLLM_INVALID_TOKEN_ID)[0]
            sampled_token_ids = sampled_token_ids[valid_rows]
        sampled_token_ids_list = sampled_token_ids.tolist()

        # (seq_id, seq_data) for the groups that actually sample a token;
        # indexed by output_index below.
        seq_data_entries = [
            (seq_id, seq_data)
            for sg in seq_group_metadata_list if sg.do_sample
            for seq_id, seq_data in sg.seq_data.items()
        ]

        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        output_index = 0
        # Make sure the non-terminal prefill chunks are still aligned with
        # their own empty output.
        for group_index, seq_group_meta in enumerate(seq_group_metadata_list):
            # Since we can get chunks here, we dont always have a sampled token
            # (only on last chunk) but we still have to provide an output.
            if not seq_group_meta.do_sample:
                completion_seq_group_output_list.append(
                    CompletionSequenceGroupOutput(samples=[],
                                                  prompt_logprobs=None))
                continue

            # Sequence with output.
            seq_id, seq_data = seq_data_entries[output_index]
            # Use the group's own position: output_index skips do_sample=False
            # groups and would read another group's flag.
            if seq_output_prompt_logprobs[group_index]:
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    )
                    # no prompt logprobs for the first token
                    for p_token_id in prompt_token_ids[1:]
                ]
            else:
                prompt_logprobs = None

            completion_seq_group_output_list.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))
            output_index += 1

        return [SamplerOutput(outputs=completion_seq_group_output_list)]

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation.

        The input is sent to the scorer model and, unless ``skip_proposer`` is
        True, to the proposer model as well, so that the KV caches of the two
        stay consistent. When ``skip_proposer`` is True the proposer model is
        not called: its KV cache is not updated for these requests, so they
        cannot enable spec decode for the rest of their decoding.

        Args:
            execute_model_req: The batch to run.
            skip_proposer: If True, do not run the proposer model.

        Returns:
            The scorer's output for this step: either ``[sampler_output]`` or,
            when ``self._disable_logprobs`` is set, the result of
            ``_serialize_sampler_output_no_logprobs``.
        """
        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # remove hidden_states for prompt tokens
            # TODO Enable `return_hidden_states`: prefill chunks hidden states
            # are pruned by the logits processor. Also, they should be arranged
            # back into full-prefill latent. Address it to enable MLPSpeculator.
            if any(seq.is_prompt
                   for seq in execute_model_req.seq_group_metadata_list):
                # Single-argument torch.where returns the indices of nonzero
                # entries, i.e. the rows whose sampled token id differs from
                # VLLM_INVALID_TOKEN_ID.
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
                    hidden_states, execute_model_req.seq_group_metadata_list)
            else:
                self.previous_hidden_states.update(
                    hidden_states, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there no
            # additional complexity in worker for spec_decode vs non_spec_decode
            # flow and execute_model doesn't need additional modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)

            self.proposer_worker.execute_model(execute_model_req)

        # Build the return value before the device tensors below are cleared.
        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
                                    if self._disable_logprobs else
                                    [sampler_output])

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers.

        This is used for both speculation cases (num_lookahead_slots>0) and
        non-speculation cases (e.g. prefill). The driver broadcasts a control
        dict, and this worker calls ``execute_model()`` on its scorer and
        proposer workers in the order that dict describes.

        Returns:
            True if there are remaining sequences to process; False when the
            broadcast dict is empty.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if data["no_spec"]:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            # Even if num_lookahead_slots is zero, we want to run the
            # proposer model as it may have KV.
            #
            # We run the proposer once per lookahead slot. In the future we
            # should delegate how many times it runs to the proposer.
            for _ in range(max(num_lookahead_slots, 1)):
                self.proposer_worker.execute_model()

        if not data["no_spec"]:
            self.scorer_worker.execute_model()
            # Presumably mirrors the driver's extra proposer forward that
            # syncs the proposer KV cache for prefills after scoring.
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        When `enable_chunked_prefill` is set, scorer will batch decodes and
        prefills, while proposer will sync its KV-cache by running an extra
        forward on prefills.

        Args:
            execute_model_req: The batch to run. Its ``num_lookahead_slots``
                must equal ``num_lookahead_slots``.
            num_lookahead_slots: Number of speculative tokens per sequence (k).
                The proposal time is divided by it, so it must be non-zero.

        Returns:
            A list of SamplerOutput, each containing a single token per
            sequence.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Pass last hidden states from target model to proposer
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        # Clear the hidden states handed to the proposer before the request
        # is passed on to the scorer.
        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
        # discard decodes that have already been processed by proposer.
        # NOTE(review): only `non_spec_indices` is filtered to prompts here;
        # `non_spec_seqs` is passed to the proposer below unfiltered. Confirm
        # this is intended.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if len(non_spec_indices):
            all_hidden_states = proposal_scores.hidden_states
            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(prefill_hidden_states)
            # Sync proposer KV cache for prefills.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:
            accepted_token_ids, target_logprobs = self._verify_tokens(
                execute_model_req.seq_group_metadata_list, proposal_scores,
                proposals, execute_model_req.num_lookahead_slots)

        # (avg proposal ms per lookahead token, scoring ms, verification ms)
        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Args:
            seq_group_metadata_list: Metadata of every sequence group in the
                batch, in batch order.
            proposal_scores: Scorer (target model) outputs for the batch.
            proposals: Proposer (draft model) outputs for the batch.
            max_proposal_len: Batch proposal length k. Non-speculative rows
                are padded with -1 to width ``max_proposal_len + 1``.

        Returns:
            A tuple of Tensors: the accepted token ids, rearranged into the
            order of ``seq_group_metadata_list``, and the logprobs according
            to the scoring model.

        When the scorer returns hidden states, this also stores the hidden
        state of the last accepted token of each sequence in
        ``self.previous_hidden_states`` for the next decode step.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        proposal_verifier_probs = proposal_scores.probs[spec_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            # The sampler only receives the speculative rows selected by
            # `spec_indices`, so each seeded generator is keyed by its row
            # position within that sub-batch, not by its index in the full
            # batch. The two differ whenever non-speculative sequences
            # (e.g. chunked prefills) precede speculative ones.
            seeded_seqs: Dict[int, Any] = {}
            for row, batch_idx in enumerate(spec_indices):
                sgm = seq_group_metadata_list[batch_idx]
                if sgm.sampling_params.seed is not None:
                    seeded_seqs[row] = self.generators[sgm.request_id]
            sampler_extra_kwargs["seeded_seqs"] = seeded_seqs

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )
        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        if hidden_states is not None:
            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]

            # Per row: count the entries that are not -1 (adding 1 maps -1 to
            # 0) and subtract one to get the gather index.
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_metadata_list,
                second_last_token_hidden_states)
        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs. One SamplerOutput is produced per step,
        and construction stops at the first step in which every sequence's
        token is -1.

        Args:
            seq_group_metadata_list: Metadata of the sequence groups in batch
                order.
            accepted_token_ids: Accepted token ids per sequence and step.
            target_logprobs: Scorer logprobs. Not read when
                ``self._disable_logprobs`` is set.
            k: Forwarded to ``maybe_collect_rejsample_metrics``.
            stage_times: (average proposal ms per token, scoring ms,
                verification ms), logged when rejection-sampling metrics are
                emitted.

        Returns:
            One SamplerOutput per emitted step. Rejection-sampling metrics,
            when available, are attached to the first element.
        """
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_dummy_logprob_lists(
                 batch_size, num_steps,
                 self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
             accepted_token_id_logprobs_by_step,
             topk_logprobs_by_step,
             topk_indices_by_step) = self._create_logprob_lists_from_tensors(
                 target_logprobs_by_step, accepted_token_ids_by_step,
                 self.scorer_worker.model_config.max_logprobs)

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize tensor to CPU Python list.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        # Non-terminal prefill chunks will end up here as rows with just -1s
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
        sampler_output_list: List[SamplerOutput] = []
        for step_index in range(num_steps):
            # Stop at the first step in which no sequence has a token.
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            # NOTE(review): raises IndexError if the loop above broke at step
            # 0 (every token -1) -- confirm that cannot happen here.
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)

        return sampler_output_list

    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.

        Args:
            average_time_per_proposal_tok_ms: Proposal time divided by the
                number of lookahead slots, in milliseconds.
            scoring_time_ms: Time spent scoring proposals, in milliseconds.
            verification_time_ms: Time spent verifying tokens, in
                milliseconds.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Build placeholder logprob structures for when logprobs are disabled.

        Four nested lists are returned, each indexed step-first:
            - accepted-token ranks, (num_steps, batch_size), filled with -1
            - accepted-token logprobs, (num_steps, batch_size), filled with 0.0
            - top-k logprobs, (num_steps, batch_size, num_top_k), all None
            - top-k token IDs, (num_steps, batch_size, num_top_k), all None

        Every inner list is a distinct object, so mutating one entry never
        affects another.

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.

        Returns:
            The four placeholder lists described above, as a tuple.
        """
        ranks: List[List[int]] = []
        token_logprobs: List[List[float]] = []
        topk_values: List[List[List[Optional[float]]]] = []
        topk_ids: List[List[List[Optional[int]]]] = []
        for _ in range(num_steps):
            ranks.append([-1] * batch_size)
            token_logprobs.append([0.0] * batch_size)
            topk_values.append(
                [[None] * num_top_k for _ in range(batch_size)])
            topk_ids.append([[None] * num_top_k for _ in range(batch_size)])
        return ranks, token_logprobs, topk_values, topk_ids

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four lists representing token probabilities and
        their ranks.

        This method returns four lists containing:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Tensor representing the
            log probabilities of the target model,
            shaped (num_steps, batch_size, vocab_size)
            accepted_token_ids_by_step (torch.Tensor): Tensor representing
            the accepted token_ids, shaped (num_steps, batch_size)
            num_top_k (int): The number of top-k token log probabilities to
            return.

        Returns:
            A tuple containing the lists described above, each produced by
            ``Tensor.tolist()``.
        """
        # Serialize all tensors to CPU Python lists.
        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step_tensor,
         accepted_token_id_logprobs_by_step_tensor
         ) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )
        # Get the top-k logprobs (which may or may not include the
        # logprob of the accepted token).
        (topk_logprobs_by_step_tensor,
         topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
             k=num_top_k,
             dim=-1,
         )
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step_tensor.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step_tensor.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
        topk_indices_by_step = topk_indices_by_step_tensor.tolist()
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Forget the bookkeeping for every request reported as finished.

        For each id in ``execute_model_req.finished_requests_ids``, the
        request's sequence ids are removed from the bonus-token set and the
        request's entry is deleted from the request-id -> seq-id mapping.
        """
        bonus_seq_ids = self._seq_with_bonus_token_in_last_step
        request_to_seq_ids = self._request_id_seq_id_mapping
        for request_id in execute_model_req.finished_requests_ids:
            bonus_seq_ids.difference_update(request_to_seq_ids[request_id])
            del request_to_seq_ids[request_id]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.

        A sequence is added to ``self._seq_with_bonus_token_in_last_step``
        when its entry in the final step row is not -1, and discarded from it
        otherwise. The request-id -> seq-ids mapping of this batch is merged
        into ``self._request_id_seq_id_mapping``.

        Args:
            seq_ids: Sequence ids in batch order.
            request_ids_seq_ids_mapping: Sequence ids belonging to each
                request in this batch.
            accepted_token_ids_by_step: Accepted token ids indexed as
                ``[step][batch_index]``; -1 marks a slot without a token.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

    @cached_property
    def _vocab_size(self) -> int:
        """Return the vocabulary size shared by the draft and target workers.

        The value is computed once and cached. The proposer (draft) and scorer
        (target) workers are asserted to report the same ``vocab_size``.
        """
        draft_vocab_size = self.proposer_worker.vocab_size
        target_vocab_size = self.scorer_worker.vocab_size
        assert draft_vocab_size == target_vocab_size
        return draft_vocab_size

    @property
    def rank(self):
        """This worker's rank, as reported by the scorer (target) worker."""
        scorer_worker = self.scorer_worker
        return scorer_worker.rank

    @property
    def device(self):
        """Device of this worker, as reported by the scorer (target) worker."""
        return self.scorer_worker.device

    @property
    def _driver_rank(self) -> int:
        """Rank of the driver worker; always 0."""
        return 0

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def start_profile(self):
        """Start profiling on the scorer worker.

        The call is forwarded only when the scorer worker is a ``WorkerBase``
        instance; otherwise this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker.

        The call is forwarded only when the scorer worker is a ``WorkerBase``
        instance; otherwise this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.stop_profile()


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model alone, calculate how many blocks should be
    given to each of the draft and target models.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone. The result is

        total_num_gpu_blocks * scorer_bytes / (proposer_bytes + scorer_bytes)

    truncated to an int.

    Args:
        scorer_cache_block_size_bytes: Bytes per cache block of the target
            (scorer) model.
        proposer_cache_block_size_bytes: Bytes per cache block of the draft
            (proposer) model.
        total_num_gpu_blocks: Number of blocks the target model alone could
            allocate.

    Returns:
        The number of blocks each model should allocate.

    Raises:
        ZeroDivisionError: if the two block sizes sum to zero.
    """
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional[HiddenStates]:
    """Shift prefill hidden states by one position for the proposer.

    For prefill step in proposer, we run the model for N-1 tokens because the
    Nth token will be processed in the first decode step. For N-1 tokens, the
    input should be 0:N-1 hidden states which should be concatenated with 1:N
    token (since output of scorer has to be the input for proposer).
    Therefore, we shift the hidden states to align the n-1th hidden state with
    the nth token.

    Args:
        prefill_hidden_states: Hidden states from the target model for the
            prefill tokens, or None.

    Returns:
        A ``HiddenStates`` wrapping the tensor rolled by one along dim 0, or
        None when ``prefill_hidden_states`` is None.
    """
    if prefill_hidden_states is None:
        return None
    # roll(shifts=1, dims=0) moves row i to row i+1; the last row wraps
    # around to row 0.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))