spec_decode_worker.py 65.1 KB
Newer Older
1
import os
2
import copy
3
from collections import defaultdict
4
from functools import cached_property
5
from typing import Any, Dict, List, Optional, Set, Tuple, Type
6
7
8

import torch

9
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
10
from vllm.distributed.communication_op import broadcast_tensor_dict
11
from vllm.logger import init_logger
12
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
13
from vllm.model_executor.layers.sampler import SamplerOutput
14
from vllm.model_executor.layers.spec_decode_base_sampler import (
15
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
16
17
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
18
from vllm.platforms import current_platform
19
20
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
21
                           HiddenStates, SequenceGroupMetadata,
22
                           get_all_seq_ids_and_request_ids, Logits)
zhuwenwen's avatar
zhuwenwen committed
23
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
24
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
25
26
27
28

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

29
30
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
31
from vllm.spec_decode.medusa_worker import MedusaWorker
32
from vllm.spec_decode.metrics import AsyncMetricsCollector
33
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
34
from vllm.spec_decode.mqa_scorer import MQAScorer
35
from vllm.spec_decode.multi_step_worker import MultiStepWorker
36
from vllm.spec_decode.ngram_worker import NGramWorker
37
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
38
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
39
from vllm.spec_decode.target_model_runner import TargetModelRunner
40
41
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
42
                                   get_all_num_logprobs,
43
                                   get_sampled_token_logprobs, nvtx_range,
44
                                   split_batch_by_proposal_len)
45
46
from vllm.worker.worker_base import (LoraNotSupportedWorkerBase, WorkerBase,
                                     WorkerWrapperBase)
47

48
49
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
王敏's avatar
王敏 committed
50
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
51
52

logger = init_logger(__name__)
53
54


55
56
57
58
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker for Executors that go through WorkerWrapper.

    The target (scorer) worker is created from a deep copy of ``vllm_config``
    whose worker class is switched to ``sd_worker_cls`` and whose model runner
    is ``TargetModelRunner``. The draft worker kwargs are derived from a second
    deep copy whose model, quantization and parallel configs come from the
    speculative config; the proposer itself is built in
    ``SpecDecodeWorker.create_worker``.

    Raises:
        NotImplementedError: if pipeline parallelism is enabled.
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    spec_cfg: SpeculativeConfig = vllm_config.speculative_config
    assert spec_cfg is not None

    if vllm_config.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError("Speculative decoding is currently "
                                  "incompatible with pipeline parallelism")

    # Snapshot the kwargs for the draft worker *before* the target-only model
    # runner override is injected into ``kwargs``.
    draft_kwargs = kwargs.copy()
    kwargs["model_runner_cls"] = TargetModelRunner

    # ---- Target (scorer) worker ----
    target_cfg = copy.deepcopy(vllm_config)
    target_parallel = target_cfg.parallel_config
    target_parallel.worker_cls = target_parallel.sd_worker_cls
    target_worker = WorkerWrapperBase(vllm_config=target_cfg)
    target_worker.init_worker(*args, **kwargs)
    # The TargetModelRunner mirrors the speculative config's logprob setting.
    target_worker.model_runner.disable_logprobs = spec_cfg.disable_logprobs

    # ---- Draft (proposer) worker config ----
    draft_cfg = copy.deepcopy(vllm_config)
    draft_cfg.model_config = spec_cfg.draft_model_config
    draft_cfg.quant_config = VllmConfig._get_quantization_config(
        draft_cfg.model_config,
        vllm_config.load_config,
    )
    # NOTE: this writes into the speculative config's draft parallel config
    # object, which is then installed as the draft worker's parallel config.
    spec_cfg.draft_parallel_config.worker_cls = (
        draft_cfg.parallel_config.sd_worker_cls)
    draft_cfg.parallel_config = spec_cfg.draft_parallel_config
    # TODO allow draft-model specific load config.

    # Draft-model specific overrides; the ngram keys are consumed (popped) by
    # SpecDecodeWorker.create_worker.
    draft_kwargs.update(
        vllm_config=draft_cfg,
        ngram_prompt_lookup_max=spec_cfg.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=spec_cfg.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_kwargs,
        disable_mqa_scorer=spec_cfg.speculative_disable_mqa_scorer,
        disable_by_batch_size=spec_cfg.speculative_disable_by_batch_size,
        draft_token_acceptance_method=spec_cfg.draft_token_acceptance_method,
        typical_acceptance_sampler_posterior_threshold=(
            spec_cfg.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            spec_cfg.typical_acceptance_sampler_posterior_alpha),
        disable_logprobs=spec_cfg.disable_logprobs,
        disable_log_stats=spec_cfg.disable_log_stats)


116
# Reminder: Please update docs/source/usage/compatibility_matrix.md
117
# if a feature combination becomes valid.
118
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
119
120
121
122
123
124
125
126
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

127
    See https://github.com/vllm-project/vllm/pull/2188 and
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
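    * Supported proposers are a draft model (MultiStepWorker), ngram prompt
        lookup (NGramWorker), MLPSpeculator and Medusa heads; contributions
        for more forms are welcome.
    * Top-1 proposal and scoring are the default. A tree-style scorer
        (BatchExpansionTreeStyleScorer) is used instead when the
        VLLM_TREE_DECODING environment variable is set to "1".
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The MQA scorer is used only when not disabled via disable_mqa_scorer,
        with the FLASH_ATTN backend, in eager mode, and when the draft model's
        max_model_len is not smaller than the target's. Otherwise scoring
        falls back to batch expansion, which is suboptimal especially as the
        batch size, proposal length, and sequence lengths grow.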
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

144
    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker, choosing proposer, sampler and scorer.

        Args:
            scorer_worker: The already-initialized target (scorer) worker.
            draft_worker_kwargs: Constructor kwargs for the proposer worker.
                Must contain ``vllm_config``, ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``. NOTE: this dict is mutated. The
                two ngram keys are popped, and ``device_type`` or
                ``model_runner_cls`` may be added.
            disable_mqa_scorer: Request batch-expansion scoring instead of the
                MQA scorer. The MQA scorer may also be disabled automatically
                (see below).
            disable_by_batch_size: Running-queue size at which speculation is
                disabled; forwarded to ``__init__``.
            draft_token_acceptance_method: "rejection_sampler" or
                "typical_acceptance_sampler".
            typical_acceptance_sampler_posterior_threshold: Forwarded to
                TypicalAcceptanceSampler.
            typical_acceptance_sampler_posterior_alpha: Forwarded to
                TypicalAcceptanceSampler.
            disable_logprobs: Forwarded to ``__init__``.
            disable_log_stats: Forwarded to ``__init__``.

        Proposer selection:
            * ``ngram_prompt_lookup_max > 0`` selects NGramWorker. The draft
              tensor parallel size must be 1.
            * Otherwise MLPSpeculatorWorker or MedusaWorker is chosen by draft
              model type, falling back to MultiStepWorker.
            * Non-ngram proposers are passed through
              SmallerTpProposerWorker.maybe_wrap_worker.

        The MQA scorer is disabled when the attention backend is not
        FLASH_ATTN, when the draft max_model_len is smaller than the target's,
        or when the target model does not run with enforce_eager.

        Raises:
            NotImplementedError: For an EAGLE draft model with draft TP > 1.
            ValueError: For an unknown ``draft_token_acceptance_method``.
        """
        allow_zero_draft_token_step = True

        # These keys are consumed here and must not reach the proposer
        # worker constructors.
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
            'vllm_config'].parallel_config
        if ngram_prompt_lookup_max > 0:
            # The draft kwargs carry the parallel config inside
            # ``vllm_config`` (read above); there is no separate
            # 'parallel_config' entry, so looking one up would raise KeyError.
            assert draft_parallel_config.tensor_parallel_size == 1, (
                "ngram speculation requires a draft tensor_parallel_size "
                "of 1")
            draft_worker_kwargs[
                "device_type"] = scorer_worker.device_config.device.type
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size

            if draft_model_config.hf_config.model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_config.hf_config.model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    if current_platform.is_cuda_alike():
                        draft_worker_kwargs[
                            "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_config.hf_config.model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")

                    # Zero-draft-token steps are disallowed when draft TP > 1
                    # (see the allow_zero_draft_token_step docs in __init__).
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=(
                    typical_acceptance_sampler_posterior_threshold),
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail early with a clear message rather than passing None into
            # __init__, which immediately reads spec_decode_sampler attributes.
            raise ValueError(
                "Unsupported draft_token_acceptance_method: "
                f"{draft_token_acceptance_method!r}. Expected "
                "'rejection_sampler' or 'typical_acceptance_sampler'.")
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            if scorer_worker.model_runner.attn_backend.get_name(
            ) != "FLASH_ATTN":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if draft_model_config and \
                draft_model_config.max_model_len < \
                    scorer_worker.model_config.max_model_len:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
247

248
249
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: Worker that produces speculative (draft) tokens
                for sequences.
            scorer_worker: Worker that produces probabilities for the
                speculative tokens according to some base model; typically a
                vanilla vLLM Worker.
            spec_decode_sampler: Torch module performing acceptance sampling
                of the draft tokens during verification. Either a
                RejectionSampler or a TypicalAcceptanceSampler; its
                ``probs_dtype`` and ``token_id_dtype`` are read here.
            disable_mqa_scorer: When True, scoring uses
                BatchExpansionTop1Scorer instead of the MQA scorer.
            disable_logprobs: When True, token log probabilities are not
                output by either the draft or the target worker; when False
                both output them.
            disable_log_stats: When True, periodic printing of speculative
                stage times is disabled.
            disable_by_batch_size: Once the batch size exceeds this value,
                speculative decoding is disabled for new incoming requests.
                A falsy value (None or 0) means no limit.
            metrics_collector: Optional metrics helper, mainly for tests; an
                AsyncMetricsCollector around the sampler is built otherwise.
            allow_zero_draft_token_step: Whether a step in which the draft
                model produces no draft token is allowed. Should be disallowed
                when the draft model's TP is larger than 1 (TODO: #5814).
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker

        runner = getattr(scorer_worker, "model_runner", None)
        if runner:
            self.generators = runner.get_generators()
        else:
            self.generators = None

        # `or` maps both None and 0 to "never auto-disable".
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step

        if metrics_collector is None:
            metrics_collector = AsyncMetricsCollector(spec_decode_sampler)
        self._metrics = metrics_collector

        # Sequence IDs that received a bonus token ID in their last forward
        # pass. Only needed when the proposer keeps a KV cache for token
        # generation (e.g. MultiStepWorker).
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Currently active request IDs mapped to their sequence IDs.
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = spec_decode_sampler.probs_dtype
        self.token_id_dtype = spec_decode_sampler.token_id_dtype

        # The scorer itself is created lazily in init_device().
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Target-model state carried over to the proposer for the next step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self.previous_logits: Optional[Logits] = None
        self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats

        # Tree-style scoring is enabled via VLLM_TREE_DECODING=1.
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'
325

326
    def init_device(self) -> None:
        """Initialize both scorer and proposer models.

        Steps:
            1. Initialize devices and load weights, scorer worker first, then
               the proposer worker.
            2. Initialize the metrics and sampler tensors.
            3. Select and build the scorer.
            4. Configure the model samplers for speculative decoding.

        Scorer selection:
            * ``VLLM_TREE_DECODING=1`` uses BatchExpansionTreeStyleScorer.
            * Otherwise ``disable_mqa_scorer`` chooses between
              BatchExpansionTop1Scorer and MQAScorer.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.tree_decoding:
            # Tree decoding overrides the MQA / batch-expansion choice, so
            # log what is actually used instead of the ignored alternative.
            scorer_cls = BatchExpansionTreeStyleScorer
            logger.info("[Speculative Decoding] Use tree-style batch "
                        "expansion for scoring proposals.")
        elif self.disable_mqa_scorer:
            scorer_cls = BatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

364
365
366
    def load_model(self, *args, **kwargs):
        """No-op: both models are loaded by ``init_device`` instead."""
        return None

367
368
369
    def _configure_model_sampler_for_spec_decode(self):
        """Make the scorer and proposer samplers emit GPU probability tensors.

        Keeping the sampled data on device avoids moving it to CPU and
        serializing it, which significantly reduces sampling overhead during
        verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        scorer_sampler = self.scorer_worker.model_runner.model.sampler
        scorer_sampler.include_gpu_probs_tensor = True

        # Tree-style decoding modifies the probs itself in _verify_tokens, so
        # in-place greedy modification is only requested otherwise.
        if not self.tree_decoding:
            scorer_sampler.should_modify_greedy_probs_inplace = True

        proposer = self.proposer_worker
        proposer.set_include_gpu_probs_tensor()
        proposer.set_should_modify_greedy_probs_inplace()
395

396
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return ``(num_gpu_blocks, num_cpu_blocks)`` for both KV caches.

        The scorer model (usually the larger one) is profiled. Its GPU block
        budget is then redistributed via ``split_num_cache_blocks_evenly``
        using the per-block byte sizes of both workers, so that scorer and
        proposer end up with the same number of GPU blocks. The CPU block
        count from profiling is returned unchanged.
        """
        scorer = self.scorer_worker
        gpu_blocks, cpu_blocks = scorer.determine_num_available_blocks()

        scorer_block_bytes = scorer.get_cache_block_size_bytes()
        proposer_block_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, gpu_blocks)
        return shared_gpu_blocks, cpu_blocks

417
418
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engines of the scorer, then the proposer,
        with identical GPU/CPU block counts."""
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                    num_cpu_blocks=num_cpu_blocks)
425
426
427

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Behavior by rank and input:
            * Non-driver ranks run one iteration of the non-driver loop and
              return an empty list.
            * On the driver, ``None`` broadcasts an empty dict (the stop
              signal for the non-driver loops) and returns an empty list.
            * Otherwise the driver broadcasts the step configuration, then
              runs either a non-speculative step or a speculative decoding
              step.
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. Use the same source rank as the per-step
            # broadcast below so both messages come from the driver.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        # Validate before the metadata list is first used; checking it after
        # iteration (as before) could never fire, since a None list would
        # already have raised an opaque TypeError.
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        all_prompt = all(sgm.is_prompt for sgm in seq_group_metadata_list)
        atleast_one_prompt = any(sgm.is_prompt
                                 for sgm in seq_group_metadata_list)
        all_zero_spec_tokens = all(sgm.num_speculative_tokens == 0
                                   for sgm in seq_group_metadata_list)

        if all_prompt and seq_group_metadata_list:
            assert num_lookahead_slots == 0, (
                "Prompt only runs should have num_lookahead_slots equal to 0. "
                "This should never happen, please file a bug at "
                "https://github.com/vllm-project/vllm/issues")
        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        # We expect `num_speculative_tokens` to be None for prefills.
        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.

        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
            # When both chunked prefill and speculative decoding are enabled
            # it is possible that the same batch contains both prefill
            # and decodes. If that happens in the scorer we run the batch
            # as one single forward pass. However, in the proposer we
            # run them as 2 different batches - one for prefill and
            # the other for decodes. The variable indicates to the non-driver
            # worker that there are prefills as part of the speculative batch
            # and hence it needs to run an extra prefill forward pass.
            run_spec_proposer_for_prefill=atleast_one_prompt,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(disable_all_speculation,
                                               seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Speculative-decoding loop for non-driver (parallel) workers.

        Calls ``_run_non_driver_rank`` repeatedly until it returns a falsy
        value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

525
526
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True once the running queue reaches the auto-disable
        threshold.

        Past that point speculative decoding stops trading throughput for
        latency.
        """
        queue_size = execute_model_req.running_queue_size
        threshold = self.disable_by_batch_size
        return queue_size >= threshold
531
532
533
534
535
536
537
538
539
540
541
542
543
544

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` on every sequence group when all
        speculation is disabled; otherwise do nothing."""
        if disable_all_speculation:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for group_metadata in seq_group_metadata_list:
                group_metadata.num_speculative_tokens = 0
545

546
547
    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """
        Creates and returns a `SamplerOutput` with only the token IDs being
        serialized to CPU and populated in `CompletionSequenceGroupOutput`.
        All other parameters in `CompletionSequenceGroupOutput` related to log
        probabilities are skipped.

        Sequence groups that do not sample (e.g. non-final prefill chunks) get
        an empty output, so the result stays aligned one-to-one with
        ``execute_model_req.seq_group_metadata_list``. Groups that request
        prompt logprobs receive placeholder entries (rank -1, logprob 0.0).

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
            was executed.
            sampler_output (SamplerOutput): The output from the sampler with
            only GPU tensors populated.

        Returns:
            SamplerOutput: A new `SamplerOutput` instance containing a list of
            `CompletionSequenceGroupOutput` objects with only token IDs
            populated.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # One flag per sequence group, including groups that do not sample.
        seq_output_prompt_logprobs = [
            seq.is_prompt and seq.sampling_params.prompt_logprobs is not None
            and seq.sampling_params.prompt_logprobs > 0
            for seq in seq_group_metadata_list
        ]

        # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID
        sampled_token_ids = sampler_output.sampled_token_ids
        if any(seq_output_prompt_logprobs):
            # subtracting is faster than testing for equality
            sampled_token_ids = sampled_token_ids[torch.where(
                sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]]
        sampled_token_ids_list = sampled_token_ids.tolist()

        # (seq_id, seq_data) for sequences of sampling groups only; empty
        # token sequences are skipped.
        seq_data_entries = [(seq_id, seq_data)
                            for sg in seq_group_metadata_list if sg.do_sample
                            for seq_id, seq_data in sg.seq_data.items()]

        completion_seq_group_output_list: List[
            CompletionSequenceGroupOutput] = []
        # Counts sampling groups only. It indexes the per-sample lists
        # (seq_data_entries, sampled_token_ids_list), while per-group lists
        # are indexed by group_index.
        output_index = 0
        for group_index, seq_group_meta in enumerate(seq_group_metadata_list):
            # Since we can get chunks here, we don't always have a sampled
            # token (only on the last chunk), but we still have to provide an
            # output so non-terminal prefill chunks stay aligned.
            if not seq_group_meta.do_sample:
                completion_seq_group_output_list.append(
                    CompletionSequenceGroupOutput(samples=[],
                                                  prompt_logprobs=None))
                continue

            seq_id, seq_data = seq_data_entries[output_index]
            # Per-group flag: index by group position. Indexing by
            # output_index would misalign as soon as a non-sampling group
            # precedes this one.
            needs_prompt_logprobs = seq_output_prompt_logprobs[group_index]
            if needs_prompt_logprobs:
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    )
                    # no prompt logprobs for the first token
                    for p_token_id in prompt_token_ids[1:]
                ]
            else:
                prompt_logprobs = None

            completion_seq_group_output_list.append(
                create_sequence_group_output(
                    token_id=sampled_token_ids_list[output_index][0],
                    token_id_logprob_rank=-1,
                    token_id_logprob=0.0,
                    seq_id=seq_id,
                    topk_token_ids=[],
                    topk_logprobs=[],
                    prompt_logprobs=prompt_logprobs))
            output_index += 1

        return [SamplerOutput(outputs=completion_seq_group_output_list)]
628

629
    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run one generation step without speculation.

        The request is always executed by the scorer (target) model. Unless
        ``skip_proposer`` is True it is then also executed by the proposer so
        that both KV caches stay consistent. When the proposer is skipped its
        KV cache is not updated for these requests, so they cannot enable
        spec decode for the rest of their decoding.

        With tree decoding, any pending KV-cache slot moves are attached to
        the request before execution. Target-model hidden states (and, with
        tree decoding, logits) are cached on ``self`` for the next proposal.

        Returns:
            ``[sampler_output]``, or the serialized no-logprobs variant when
            logprobs are disabled.
        """
        if self.tree_decoding:
            pending_moves = self.kvcache_slot_to_be_moved
            if pending_moves is not None:
                # Hand pending slot moves to this request exactly once.
                execute_model_req.kvcache_slot_to_be_moved = pending_moves
                self.kvcache_slot_to_be_moved = None

        outputs = self.scorer_worker.execute_model(execute_model_req)
        assert len(outputs) == 1
        target_output = outputs[0]

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        # Cache hidden states produced by the target model.
        target_hidden = target_output.hidden_states
        if target_hidden is not None:
            # TODO Enable `return_hidden_states`: prefill chunks hidden states
            # are pruned by the logits processor. Also, they should be arranged
            # back into full-prefill latent. Address it to enable MLPSpeculator.
            has_prompt = any(sgm.is_prompt for sgm in seq_group_metadata_list)
            if has_prompt:
                # Keep only rows whose sampled token is a real token; prompt
                # slots are filled with VLLM_INVALID_TOKEN_ID. Subtraction
                # yields a non-zero value exactly for those rows.
                keep_rows = torch.where(target_output.sampled_token_ids -
                                        VLLM_INVALID_TOKEN_ID)[0]
                target_hidden = target_hidden[keep_rows]
            if self.previous_hidden_states is not None:
                self.previous_hidden_states.update(target_hidden,
                                                   seq_group_metadata_list)
            else:
                self.previous_hidden_states = HiddenStates(
                    target_hidden, seq_group_metadata_list)

        # With tree decoding, also cache the target model's logits.
        if self.tree_decoding:
            target_logits = target_output.logits
            if target_logits is not None:
                if self.previous_logits is not None:
                    self.previous_logits.update(target_logits,
                                                seq_group_metadata_list)
                else:
                    self.previous_logits = Logits(target_logits,
                                                  seq_group_metadata_list)

        if not skip_proposer:
            # Prefill hidden states are prepared here so the worker's
            # execute_model needs no spec-decode specific handling.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    target_output.prefill_hidden_states)
            self.proposer_worker.execute_model(execute_model_req)

        if self._disable_logprobs:
            result = self._serialize_sampler_output_no_logprobs(
                execute_model_req=execute_model_req,
                sampler_output=target_output)
        else:
            result = [target_output]

        # Drop device tensors so less data is shipped when the engine runs in
        # a different process than the workers.
        target_output.sampled_token_probs = None
        target_output.sampled_token_ids = None
        target_output.logprobs = None
        return result
697

698
    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers.

        Used for both speculation steps (num_lookahead_slots > 0) and
        non-speculation steps (e.g. prefill). The driver broadcasts a control
        dict for the step; this rank issues the proposer/scorer calls that the
        dict's flags select.

        Returns:
            True if there are remaining sequences to process, False when the
            driver broadcast an empty dict.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]
        no_spec = data["no_spec"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if no_spec:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            if isinstance(self.proposer_worker, NonLLMProposerWorkerBase):
                # Non-LLM proposers are driven through sampler_output, and
                # only on steps that actually speculate.
                if not no_spec:
                    self.proposer_worker.sampler_output(
                        None, num_lookahead_slots, None)
            else:
                # Even if num_lookahead_slots is zero, we want to run the
                # proposer model as it may have KV.
                #
                # We run the proposer once per lookahead slot. In the future we
                # should delegate how many times it runs to the proposer.
                for _ in range(max(num_lookahead_slots, 1)):
                    self.proposer_worker.execute_model()

        if not no_spec:
            self.scorer_worker.execute_model()
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True
748

749
750
    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Run one propose -> score -> verify round of speculative decoding.

        The proposer worker produces k speculative tokens per sequence and the
        scorer (target) model scores them. The accepted tokens are then
        determined and turned into sampler outputs.

        When `enable_chunked_prefill` is set, the scorer batches decodes with
        prefills, and the proposer syncs its KV-cache by running an extra
        forward over the prefills.

        With tree decoding, the draft tree's attention masks and position ids
        are passed to the scorer. After verification, KV-cache slot moves for
        the accepted path are planned via ``move_caches``.

        Returns:
            A list of SamplerOutput, each containing a single token per
            sequence.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Move state cached from the previous target step onto this request:
        # hidden states, logits and pending KV-cache slot moves. Each is
        # cleared on the worker so it is not handed over again.
        for carried in ("previous_hidden_states", "previous_logits",
                        "kvcache_slot_to_be_moved"):
            setattr(execute_model_req, carried, getattr(self, carried))
            setattr(self, carried, None)

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        if self.tree_decoding:
            # The target model needs the tree attention mask and positions.
            execute_model_req.tree_attn_masks = proposals.tree_attn_masks
            execute_model_req.tree_position_ids = proposals.tree_position_ids

        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req, proposals)

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking the non-speculative part contains prefills
        # too. Decodes were already processed by the proposer, so only the
        # prompt indices are kept.
        prompt_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if prompt_indices:
            scored_hidden_states = proposal_scores.hidden_states
            # TODO fix `return_hidden_states`, same as in `_run_no_spec`
            if scored_hidden_states is not None:
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(
                        scored_hidden_states[prompt_indices])
            # Bring the proposer's KV cache up to date for the prefills.
            self.proposer_worker.execute_model(
                execute_model_req.clone(non_spec_seqs))

        with Timer() as verification_timer:
            (accepted_token_ids, target_logprobs, select_indices_list,
             accept_lengths) = self._verify_tokens(
                 execute_model_req.seq_group_metadata_list, proposal_scores,
                 proposals, execute_model_req.num_lookahead_slots)
            if self.tree_decoding:
                # Plan the KV-cache moves for the accepted tree path.
                self.move_caches(execute_model_req, select_indices_list,
                                 accept_lengths)

        stage_times = (
            proposal_timer.elapsed_time_ms / num_lookahead_slots,
            scoring_timer.elapsed_time_ms,
            verification_timer.elapsed_time_ms,
        )
        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)
841
842
843
844
845
846
847
848

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
               Optional[List[int]]]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        As a side effect, the target model's hidden state at the last accepted
        position is stored on ``self.previous_hidden_states``. On the
        tree-candidate path (``proposals.cart_candidates`` is not None) the
        matching logits are also stored on ``self.previous_logits``.

        Returns:
            A 4-tuple ``(accepted_token_ids, logprobs, select_indices_list,
            accept_lengths)``:

            - accepted_token_ids: accepted token ids, rows ordered like
              ``seq_group_metadata_list`` and padded with -1.
            - logprobs: ``proposal_scores.logprobs`` as-is (not reordered).
            - select_indices_list: per sequence, the slice of
              ``retrieve_indices`` along the accepted path. This is an empty
              list when there are no cart candidates.
            - accept_lengths: the per-sequence accept lengths collected through
              the sampler kwargs, or None when there are no cart candidates.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        if non_spec_indices:
            proposal_verifier_probs = proposal_scores.probs[spec_indices]
        else:
            proposal_verifier_probs = proposal_scores.probs

        if self.tree_decoding:
            # Gather target probabilities along the tree's candidate paths.
            retrieve_indices = proposals.retrieve_indices
            proposal_verifier_probs = proposal_verifier_probs[:,
                                                              retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method (may be None).
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and non_spec_indices:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers (may be None).
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and non_spec_indices:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            # Presumably filled in place by the sampler; both lists are read
            # back below on the tree-candidate path.
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []
            # One flag per sequence group, taken from its first sequence.
            sampler_extra_kwargs["first_step_flags"] = [
                bool(next(iter(sgm.seq_data.values())).get_first_step_flag())
                for sgm in seq_group_metadata_list
            ]

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to the accepted
        # token ids tensor. Only column 0 holds a real token; the rest is -1.
        # The tree path uses max_proposal_len columns and the linear path uses
        # max_proposal_len + 1.
        num_output_slots = (max_proposal_len
                            if self.tree_decoding else max_proposal_len + 1)
        non_spec_token_ids = non_spec_token_ids.expand(
            -1, num_output_slots).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        hidden_states = proposal_scores.hidden_states
        accept_lengths = None
        select_indices_list = []

        if cart_candidates is None:
            if hidden_states is not None:
                # Contract hidden states based on accepted tokens
                hs_size = hidden_states.shape[-1]

                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
                index = accepted_index[:, None, None].expand(-1, 1, hs_size)
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
                # Store hidden states from target model for subsequent decode step
                self.previous_hidden_states = HiddenStates(
                    hidden_states, seq_group_metadata_list,
                    second_last_token_hidden_states)
        else:
            retrieve_indices = proposals.retrieve_indices

            batch_size = len(seq_group_metadata_list)

            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Contract hidden states based on accepted tokens
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            previous_logits_list = []
            previous_hidden_state_list = []

            retrieve_indices = retrieve_indices.cpu()

            for i in range(batch_size):
                logit = logits[i, best_candidates[i],
                               accept_lengths[i]].unsqueeze(0)
                previous_logits_list.append(logit)
                # Tree node indices along the accepted path (root included).
                select_indices = retrieve_indices[
                    best_candidates[i], :accept_lengths[i] + 1]
                # Hidden state of the last accepted node.
                hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
                select_indices_list.append(select_indices)
                previous_hidden_state_list.append(hidden_state)

            logits = torch.cat(previous_logits_list, dim=0)
            self.previous_logits = Logits(logits, seq_group_metadata_list)

            # One row per sequence: [batch_size, hs_size].
            hidden_states = torch.cat(previous_hidden_state_list, dim=0)
            self.previous_hidden_states = HiddenStates(hidden_states,
                                                       seq_group_metadata_list,)

        return accepted_token_ids, logprobs, select_indices_list, accept_lengths
zhuwenwen's avatar
zhuwenwen committed
1004

1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
    
    def move_caches(self, execute_model_req: ExecuteModelRequest,
                    select_indices_list: List[torch.Tensor],
                    accept_lengths: List[int]):
        """Plan KV-cache relocation for the tokens accepted by tree decoding.

        For every sequence with a positive accept length, the cache slots of
        the selected tree nodes (sources) are paired with the slots of the
        positions directly after the sequence's base length (destinations).
        The pairs are stored as a ``[num_moves, 2]`` long tensor
        ``(src_slot, dst_slot)`` on ``self.kvcache_slot_to_be_moved``. No cache
        data is copied here.

        Side effect: clears the first-step flag of every sequence in the
        request.

        Raises:
            RuntimeError: if the scorer's model input exposes no
                ``attn_metadata.block_tables_list``.
        """
        base_lens = []
        for seq_group in execute_model_req.seq_group_metadata_list:
            for seq_data in seq_group.seq_data.values():
                full_len = seq_data.get_len()
                chunk = seq_group.token_chunk_size
                base_len = min(full_len, (full_len - 1) + chunk)

                # The first tree-decoding step ignores the first generated
                # token.
                if seq_data.get_first_step_flag():
                    base_len -= 1

                # Moving caches is the last stage of a tree-decoding step, so
                # the first-step flag is cleared here.
                seq_data.set_first_step_flag(False)
                base_lens.append(base_len)

        model_input = self.scorer._scorer_worker.model_input
        attn_metadata = getattr(model_input, 'attn_metadata', None)
        block_tables = getattr(attn_metadata, 'block_tables_list', None)
        if block_tables is None:
            raise RuntimeError("Can not get block_tables from model_input.")

        cache_engine = self.scorer._scorer_worker.cache_engines[
            execute_model_req.virtual_engine]
        block_size = cache_engine.block_size
        batch_size = len(select_indices_list)
        block_table_stride = len(block_tables) // batch_size

        src_slots: List[int] = []
        dst_slots: List[int] = []
        for i, selected in enumerate(select_indices_list):
            num_accepted = accept_lengths[i]
            if num_accepted <= 0:
                continue

            base = base_lens[i]
            table_row = i * block_table_stride

            # Source positions: the selected nodes after the first entry.
            src_positions = (selected[1:] + base).tolist()
            self.compute_slot_mapping(src_slots, table_row, src_positions,
                                      block_size, block_tables)

            # Destination positions: base+1 .. base+num_accepted.
            dst_positions = (torch.arange(1, num_accepted + 1) + base).tolist()
            self.compute_slot_mapping(dst_slots, table_row, dst_positions,
                                      block_size, block_tables)

        if src_slots:
            src = torch.tensor(src_slots, dtype=torch.long, device=self.device)
            dst = torch.tensor(dst_slots, dtype=torch.long, device=self.device)
            # Column 0 = source slot, column 1 = destination slot.
            self.kvcache_slot_to_be_moved = torch.stack([src, dst], dim=1)
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
    
    def compute_slot_mapping(self, slot_mapping: List[int],
                             seq_id: int, select_indices: List[int],
                             block_size: int,
                             block_tables: List[List[int]]) -> None:
        """Append the KV-cache slot id of each token position to a list.

        A position ``p`` maps to block ``block_table[p // block_size]`` at
        offset ``p % block_size``. Its slot id is
        ``block_number * block_size + offset``.

        Args:
            slot_mapping: Output list, extended in place with one slot id per
                entry of ``select_indices``, in order.
            seq_id: Row of ``block_tables`` to use. Despite the name, callers
                in this file pass a block-table row index
                (``i * block_table_stride``) rather than a sequence id.
            select_indices: Token positions within the sequence.
            block_size: Number of token slots per KV-cache block.
            block_tables: Rows of physical block numbers.

        Raises:
            IndexError: if a position falls beyond the row's block table.
        """
        block_table = block_tables[seq_id]
        for position in select_indices:
            block_idx, block_offset = divmod(position, block_size)
            slot_mapping.append(block_table[block_idx] * block_size +
                                block_offset)
1089

1090
1091
1092
1093
    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
        stage_times: Tuple[float, float, float],
    ) -> List[SamplerOutput]:
        """Build one SamplerOutput per decoding step from accepted token ids.

        Sequences that accepted fewer tokens are padded with -1, so each step
        has exactly one entry per sequence. Output stops at the first step
        where every sequence holds -1. When rejection-sampling metrics are
        available (and at least one step was emitted), they are attached to
        the first output and the stage times are logged.
        """
        batch_size, num_steps = accepted_token_ids.shape
        # Step-major view: row s holds the token at step s for every sequence.
        ids_by_step_tensor = accepted_token_ids.transpose(0, 1)
        max_logprobs = self.scorer_worker.model_config.max_logprobs

        if self._disable_logprobs:
            # Logprobs are skipped: avoid serializing the logprob tensors
            # from the GPU and use placeholder lists instead.
            logprob_lists = self._create_dummy_logprob_lists(
                batch_size, num_steps, max_logprobs)
        else:
            logprob_lists = self._create_logprob_lists_from_tensors(
                target_logprobs.transpose(0, 1), ids_by_step_tensor,
                max_logprobs)
        (ranks_by_step, token_logprobs_by_step, topk_logprobs_by_step,
         topk_ids_by_step) = logprob_lists

        # Sequence ids and per-sequence num_logprobs sampling parameter.
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        accepted_ids_by_step = ids_by_step_tensor.tolist()

        # Non-terminal prefill chunks show up here as rows of only -1s,
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]]
        sampler_output_list: List[SamplerOutput] = []
        for step_index, step_ids in enumerate(accepted_ids_by_step):
            if all(token_id == -1 for token_id in step_ids):
                break

            step_ranks = ranks_by_step[step_index]
            step_token_logprobs = token_logprobs_by_step[step_index]
            step_topk_ids = topk_ids_by_step[step_index]
            step_topk_logprobs = topk_logprobs_by_step[step_index]

            step_outputs: List[CompletionSequenceGroupOutput] = []
            for seq_index in range(batch_size):
                # Each sequence may request a different number of logprobs.
                limit = num_logprobs_per_seq[seq_index]
                step_outputs.append(
                    create_sequence_group_output(
                        token_id=step_ids[seq_index],
                        token_id_logprob_rank=step_ranks[seq_index],
                        token_id_logprob=step_token_logprobs[seq_index],
                        seq_id=seq_ids[seq_index],
                        topk_token_ids=step_topk_ids[seq_index][:limit],
                        topk_logprobs=step_topk_logprobs[seq_index][:limit],
                    ))
            sampler_output_list.append(SamplerOutput(outputs=step_outputs))

        # Record which sequences received a bonus token this step.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_ids_by_step)

        rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(k)
        if rejsample_metrics is None or not sampler_output_list:
            return sampler_output_list

        sampler_output_list[0].spec_decode_worker_metrics = rejsample_metrics
        # The rejection sampler emits metrics periodically, so stage times
        # are logged on the same cadence.
        self._maybe_log_stage_times(*stage_times)
        return sampler_output_list

1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four lists representing token probabilities and
        their ranks.

        This method initializes and returns four lists containing:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Tensor representing the
            log probabilities of the target model,
            shaped (num_steps, batch_size, vocab_size)
            accepted_token_ids_by_step (torch.Tensor): Tensor representing
            the accepted  token_ids, shaped (num_steps, batch_size) 
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing the lists as described above.
        """
        # Serialize all tensors to CPU Python lists.
        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step_tensor,
         accepted_token_id_logprobs_by_step_tensor
         ) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )
        # Get the top-k logprobs (which may or may not include the
        # logprob of the accepted token).
        (topk_logprobs_by_step_tensor,
         topk_indices_by_step_tensor) = target_logprobs_by_step.topk(
             k=num_top_k,
             dim=-1,
         )
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step_tensor.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step_tensor.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist()
        topk_indices_by_step = topk_indices_by_step_tensor.tolist()
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Removes the finished requests and their associated sequence ids from
        internal book keeping data structures.
        """
        for finished_request in execute_model_req.finished_requests_ids:
            for seq_id in self._request_id_seq_id_mapping[finished_request]:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            del self._request_id_seq_id_mapping[finished_request]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, delegated to the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of this worker, delegated to the scorer worker."""
        scorer = self.scorer_worker
        return scorer.device

1348
1349
1350
1351
    @property
    def _driver_rank(self) -> int:
        return 0

1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
    def get_cache_block_size_bytes(self):
        """Not supported: always raises ``NotImplementedError``.

        This hook exists only so that workers can be composed inside a
        SpecDecodeWorker. Nesting a SpecDecodeWorker inside another
        SpecDecodeWorker is left undefined for now, though it could be
        added later. See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError()

1362
    def start_profile(self):
        """Start profiling on the scorer worker if it is a ``WorkerBase``.

        For any other scorer worker type this is a no-op.
        """
        scorer = self.scorer_worker
        if not isinstance(scorer, WorkerBase):
            return
        scorer.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker if it is a ``WorkerBase``.

        For any other scorer worker type this is a no-op.
        """
        scorer = self.scorer_worker
        if not isinstance(scorer, WorkerBase):
            return
        scorer.stop_profile()

1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model alone, calculate how many blocks should be
    given to the draft and target models.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Size of one target-model cache block.
        proposer_cache_block_size_bytes: Size of one draft-model cache block.
        total_num_gpu_blocks: Blocks that fit if only the target model
            allocated cache.

    Returns:
        floor(total_num_gpu_blocks * scorer_cache_block_size_bytes /
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)).

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    combined_block_size_bytes = (proposer_cache_block_size_bytes +
                                 scorer_cache_block_size_bytes)
    # Use exact integer floor division. `int(a / b)` would route through a
    # float, which can round up once the product exceeds 2**53 and hand out
    # more blocks than actually fit.
    new_num_gpu_blocks = (total_num_gpu_blocks *
                          scorer_cache_block_size_bytes //
                          combined_block_size_bytes)

    return int(new_num_gpu_blocks)
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional["HiddenStates"]:
    """Shift the scorer's prefill hidden states by one row for the proposer.

    For the prefill step in the proposer, we run the model for N-1 tokens
    because the Nth token will be processed in the first decode step. For
    those N-1 tokens, the input should be hidden states 0:N-1 concatenated
    with tokens 1:N, since the scorer's output is the proposer's input.
    Rolling by one along dim 0 aligns the (n-1)th hidden state with the nth
    token.

    Args:
        prefill_hidden_states: Hidden states from the scorer's prefill, or
            None.

    Returns:
        A ``HiddenStates`` wrapping the rolled tensor, or None when
        ``prefill_hidden_states`` is None.
    """
    if prefill_hidden_states is None:
        return None
    # torch.roll wraps around: row 0 receives what was the last row.
    # NOTE(review): presumably that slot is not consumed downstream; confirm.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))