spec_decode_worker.py 69.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import copy
5
from collections import defaultdict
6
from functools import cached_property
7
from typing import Any, Dict, List, Optional, Set, Tuple, Type
8
9

import torch
10
import torch.nn as nn
11

12
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
13
from vllm.distributed.communication_op import broadcast_tensor_dict
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
16
from vllm.model_executor.layers.sampler import SamplerOutput
17
from vllm.model_executor.layers.spec_decode_base_sampler import (
18
    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
19
20
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)
21
from vllm.platforms import current_platform
22
23
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, ExecuteModelRequest,
24
                           HiddenStates, SequenceGroupMetadata,
25
                           get_all_seq_ids_and_request_ids, Logits)
zhuwenwen's avatar
zhuwenwen committed
26
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
27
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
28
29
30
31

if current_platform.is_cuda_alike():
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner

32
33
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
34
from vllm.spec_decode.medusa_worker import MedusaWorker
35
from vllm.spec_decode.metrics import AsyncMetricsCollector
36
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
37
from vllm.spec_decode.mqa_scorer import MQAScorer
38
from vllm.spec_decode.multi_step_worker import MultiStepWorker
39
from vllm.spec_decode.ngram_worker import NGramWorker
40
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
41
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
42
from vllm.spec_decode.target_model_runner import TargetModelRunner
43
44
from vllm.spec_decode.util import (Timer, create_logprobs_output,
                                   create_sequence_group_output,
45
                                   get_all_num_logprobs,
46
                                   get_sampled_token_logprobs, nvtx_range,
47
                                   split_batch_by_proposal_len)
48
49
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
50

51
52
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
王敏's avatar
王敏 committed
53
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
54
55

logger = init_logger(__name__)
56
57


58
59
60
61
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Build a SpecDecodeWorker from the speculative part of the config.

    This is the entrypoint for executors that go through WorkerWrapper. The
    target (scorer) worker is instantiated here; the draft (proposer) worker
    is described by a separate kwargs dict and is built inside
    ``SpecDecodeWorker.create_worker``.

    Raises:
        NotImplementedError: if pipeline parallelism is enabled.
    """
    vllm_config: VllmConfig = kwargs.get("vllm_config")
    spec_config: SpeculativeConfig = vllm_config.speculative_config
    assert spec_config is not None

    if vllm_config.parallel_config.pipeline_parallel_size > 1:
        raise NotImplementedError("Speculative decoding is currently "
                                  "incompatible with pipeline parallelism")

    # Snapshot the kwargs before the target-only runner class is injected so
    # the draft worker does not inherit it.
    proposer_kwargs = dict(kwargs)

    # --- target (scorer) worker -------------------------------------------
    kwargs["model_runner_cls"] = TargetModelRunner
    scorer_config = copy.deepcopy(vllm_config)
    scorer_parallel = scorer_config.parallel_config
    scorer_parallel.worker_cls = scorer_parallel.sd_worker_cls
    scorer_worker_cls = resolve_obj_by_qualname(scorer_parallel.worker_cls)
    target_worker = scorer_worker_cls(*args, **kwargs)
    # Propagate the speculative config's logprob setting to the
    # TargetModelRunner instance.
    target_worker.model_runner.disable_logprobs = spec_config.disable_logprobs

    # --- draft (proposer) worker config -----------------------------------
    draft_config = copy.deepcopy(vllm_config)
    draft_config.model_config = spec_config.draft_model_config
    draft_config.quant_config = VllmConfig._get_quantization_config(
        draft_config.model_config, vllm_config.load_config)
    draft_parallel = spec_config.draft_parallel_config
    # sd_worker_cls is read from the deep-copied parallel config *before*
    # that config is replaced by the draft parallel config on the next line.
    draft_parallel.worker_cls = draft_config.parallel_config.sd_worker_cls
    draft_config.parallel_config = draft_parallel
    # TODO allow draft-model specific load config.

    # Override draft-model specific worker args.
    proposer_kwargs.update(
        vllm_config=draft_config,
        ngram_prompt_lookup_max=spec_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=spec_config.ngram_prompt_lookup_min,
    )

    return SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=proposer_kwargs,
        disable_mqa_scorer=spec_config.speculative_disable_mqa_scorer,
        disable_by_batch_size=spec_config.speculative_disable_by_batch_size,
        draft_token_acceptance_method=(
            spec_config.draft_token_acceptance_method),
        typical_acceptance_sampler_posterior_threshold=(
            spec_config.typical_acceptance_sampler_posterior_threshold),
        typical_acceptance_sampler_posterior_alpha=(
            spec_config.typical_acceptance_sampler_posterior_alpha),
        disable_logprobs=spec_config.disable_logprobs,
        disable_log_stats=spec_config.disable_log_stats,
    )


120
# Reminder: Please update docs/source/features/compatibility_matrix.md
# if the feature combination becomes valid.
122
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    Supported proposers: a draft model (MultiStepWorker), n-gram prompt lookup
    (NGramWorker), an MLP speculator (MLPSpeculatorWorker) and Medusa heads
    (MedusaWorker).

    The current implementation has the following limitations:
    * Top-1 proposal and scoring is the default. Tree-style scoring
        (BatchExpansionTreeStyleScorer) is only used when the environment
        variable VLLM_TREE_DECODING is set to "1".
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The MQA scorer is only used with the FLASH_ATTN attention backend, in
        eager mode, and when the draft model's max_model_len is not smaller
        than the target model's. Otherwise scoring falls back to batch
        expansion, which is suboptimal especially as the batch size, proposal
        length, and sequence lengths grow.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

148
    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_mqa_scorer: bool,
        disable_by_batch_size: Optional[int],
        draft_token_acceptance_method: str,
        typical_acceptance_sampler_posterior_threshold: float,
        typical_acceptance_sampler_posterior_alpha: float,
        disable_logprobs: bool,
        disable_log_stats: bool,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker around an existing scorer worker.

        Chooses the proposer worker (n-gram, MLP speculator, Medusa or a
        multi-step draft model), the acceptance sampler, and whether the MQA
        scorer may be used.

        Args:
            scorer_worker: The already-constructed target (scorer) worker.
            draft_worker_kwargs: Constructor kwargs for the proposer worker.
                Must contain ``vllm_config`` (the draft-model config) and the
                ``ngram_prompt_lookup_max``/``ngram_prompt_lookup_min`` keys.
                The two n-gram keys are popped, and ``device_type`` or
                ``model_runner_cls`` may be added, so the dict is mutated.
            disable_mqa_scorer: Request batch-expansion scoring instead of
                MQA. May be forced to True below when MQA is not usable.
            disable_by_batch_size: Running-queue size at which speculation is
                turned off; ``None`` means never.
            draft_token_acceptance_method: Either ``"rejection_sampler"`` or
                ``"typical_acceptance_sampler"``.
            typical_acceptance_sampler_posterior_threshold: Passed to
                TypicalAcceptanceSampler as ``posterior_threshold``.
            typical_acceptance_sampler_posterior_alpha: Passed to
                TypicalAcceptanceSampler as ``posterior_alpha``.
            disable_logprobs: Forwarded to SpecDecodeWorker.
            disable_log_stats: Forwarded to SpecDecodeWorker.

        Raises:
            AssertionError: n-gram speculation with a draft TP other than 1.
            NotImplementedError: EAGLE draft model with TP > 1.
            ValueError: unknown ``draft_token_acceptance_method``.
        """
        allow_zero_draft_token_step = True
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        # The draft parallel config lives inside ``vllm_config``
        # (create_spec_worker swaps the draft parallel config in there).
        # ``draft_worker_kwargs`` is not guaranteed to carry a separate
        # ``parallel_config`` entry, so this is the only source used below.
        draft_model_config = draft_worker_kwargs["vllm_config"].model_config
        draft_parallel_config: ParallelConfig = draft_worker_kwargs[
            'vllm_config'].parallel_config

        if ngram_prompt_lookup_max > 0:
            assert draft_parallel_config.tensor_parallel_size == 1, (
                "n-gram speculation requires a draft tensor parallel size "
                "of 1")
            draft_worker_kwargs[
                "device_type"] = scorer_worker.device_config.device.type
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            draft_tp = draft_parallel_config.tensor_parallel_size
            target_tp = scorer_worker.parallel_config.tensor_parallel_size
            draft_model_type = draft_model_config.hf_config.model_type

            if draft_model_type == "mlp_speculator":
                proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
            elif draft_model_type == "medusa":
                proposer_worker = MedusaWorker(**draft_worker_kwargs)
            else:
                if draft_tp == 1:
                    # TP1DraftModelRunner is only imported on CUDA-alike
                    # platforms (see the guarded import at file top).
                    if current_platform.is_cuda_alike():
                        draft_worker_kwargs[
                            "model_runner_cls"] = TP1DraftModelRunner
                else:
                    if draft_model_type == "eagle":
                        raise NotImplementedError(
                            "EAGLE does not support TP > 1 yet")
                    allow_zero_draft_token_step = False
                proposer_worker = MultiStepWorker(**draft_worker_kwargs)

            proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
                proposer_worker, draft_tp, target_tp)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        spec_decode_sampler: SpecDecodeBaseSampler
        if draft_token_acceptance_method == "rejection_sampler":
            spec_decode_sampler = RejectionSampler()
        elif draft_token_acceptance_method == "typical_acceptance_sampler":
            spec_decode_sampler = TypicalAcceptanceSampler(
                posterior_threshold=(
                    typical_acceptance_sampler_posterior_threshold),
                posterior_alpha=typical_acceptance_sampler_posterior_alpha,
            )
        else:
            # Fail here with a clear message instead of handing a None
            # sampler to SpecDecodeWorker.__init__.
            raise ValueError("Unknown draft_token_acceptance_method: "
                             f"{draft_token_acceptance_method!r}")
        logger.info(
            "[Speculative Decoding] Configuring"
            " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))

        if not disable_mqa_scorer:
            if scorer_worker.model_runner.attn_backend.get_name(
            ) != "FLASH_ATTN":
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "MQA is only available with flash attn backend.")

            if (draft_model_config and draft_model_config.max_model_len
                    < scorer_worker.model_config.max_model_len):
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "draft model max_model_len is smaller than the target "
                    "model max_model_len.")

            if not scorer_worker.model_runner.model_config.enforce_eager:
                disable_mqa_scorer = True
                logger.info(
                    "[Speculative Decoding] Disabling MQA scorer as the "
                    "target model is not running in eager mode.")

        return SpecDecodeWorker(
            proposer_worker,
            scorer_worker,
            disable_mqa_scorer=disable_mqa_scorer,
            disable_logprobs=disable_logprobs,
            disable_log_stats=disable_log_stats,
            disable_by_batch_size=disable_by_batch_size,
            spec_decode_sampler=spec_decode_sampler,
            allow_zero_draft_token_step=allow_zero_draft_token_step)
251

252
253
    def __init__(
        self,
        proposer_worker: ProposerWorkerBase,
        scorer_worker: WorkerBase,
        spec_decode_sampler: SpecDecodeBaseSampler,
        disable_mqa_scorer: bool = False,
        disable_logprobs: bool = False,
        disable_log_stats: bool = False,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
        allow_zero_draft_token_step: Optional[bool] = True,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            spec_decode_sampler: A Torch module used to perform acceptance
                sampling of the draft tokens in the verification step of
                speculative decoding. Currently we support two different
                types of sampler namely RejectionSampler and
                TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
                instance of RejectionSampler or TypicalAcceptanceSampler.
            disable_mqa_scorer: If set to True, disable the MQA scorer and use
                the BatchExpansionTop1Scorer instead.
            disable_logprobs: If set to True, token log probabilities will
                not be output in both the draft worker and the target worker.
                If set to False, log probabilities will be output by both.
            disable_log_stats: If set to True, disable periodic printing of
                speculative stage times.
            disable_by_batch_size: If the batch size is larger than this,
                disable speculative decoding for new incoming requests.
                ``None`` (or 0) is treated as "never disable".
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes.
            allow_zero_draft_token_step: whether to allow a step where the draft
                model generates no draft token; should disallow when the tp of
                draft model is larger than 1 (TODO: #5814)

        The environment variable ``VLLM_TREE_DECODING`` is read once here;
        the value "1" enables tree-style decoding.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
        self.generators = scorer_runner.get_generators(
        ) if scorer_runner else None
        # A falsy threshold (None or 0) means speculation is never disabled
        # because of batch size.
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.spec_decode_sampler = spec_decode_sampler
        self._allow_zero_draft_token_step = allow_zero_draft_token_step
        self._metrics = AsyncMetricsCollector(
            self.spec_decode_sampler
        ) if metrics_collector is None else metrics_collector
        # Tracks the sequence IDs that received a bonus token ID in
        # their last forward pass. Needed only if KV cache is being
        # used for token generation such as in the case of MultiStepWorker.
        self._seq_with_bonus_token_in_last_step: Set[int] = set()
        # Tracks the currently active request ids and the sequence IDs
        # corresponding to them
        self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set)

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
        # Lazy initialization: assigned in init_device(). It may hold an
        # MQAScorer, a BatchExpansionTop1Scorer or a
        # BatchExpansionTreeStyleScorer, so it is typed as the common
        # SpeculativeScorer interface.
        self.scorer: SpeculativeScorer
        self.disable_mqa_scorer = disable_mqa_scorer

        # Hidden states from target model to pass to proposer
        # in the subsequent step.
        self.previous_hidden_states: Optional[HiddenStates] = None
        self.previous_logits: Optional[Logits] = None
        self.kvcache_slot_to_be_moved: Optional[torch.Tensor] = None
        self._disable_logprobs = disable_logprobs
        self._disable_log_stats = disable_log_stats

        self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
329

330
    def init_device(self) -> None:
        """Initialize both scorer and proposer models, then build the scorer.

        The scorer class is chosen as follows:
        * tree decoding enabled (``VLLM_TREE_DECODING=1``):
          BatchExpansionTreeStyleScorer, regardless of ``disable_mqa_scorer``;
        * ``disable_mqa_scorer``: BatchExpansionTop1Scorer;
        * otherwise: MQAScorer.
        The log message reports the scorer that is actually used.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_tensors(self.rank, device_type=self.device)
        self.spec_decode_sampler.init_tensors(self.rank,
                                              device_type=self.device)

        scorer_cls: Type[SpeculativeScorer]
        if self.tree_decoding:
            scorer_cls = BatchExpansionTreeStyleScorer
            logger.info("[Speculative Decoding] Use tree-style batch "
                        "expansion for scoring proposals.")
        elif self.disable_mqa_scorer:
            scorer_cls = BatchExpansionTop1Scorer
            logger.info("[Speculative Decoding] Use batch "
                        "expansion for scoring proposals.")
        else:
            scorer_cls = MQAScorer
            logger.info(
                "[Speculative Decoding] Use MQA scorer for scoring proposals.")

        self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
                                 device=self.device,
                                 vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

368
369
370
    def load_model(self, *args, **kwargs):
        # Deliberately a no-op: the scorer and proposer models are both
        # loaded inside init_device().
        return None

371
372
373
    def _configure_model_sampler_for_spec_decode(self):
        """Make the model samplers emit GPU probability tensors.

        Keeping sampler outputs on device lets spec decode avoid transferring
        them to CPU and serializing them, which significantly reduces the
        overhead of sampling during verification.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        target_sampler = self.scorer_worker.model_runner.model.sampler
        target_sampler.include_gpu_probs_tensor = True

        # Tree-style decoding modifies the probabilities itself later (in
        # _verify_tokens), so in-place greedy modification is only requested
        # for the regular path.
        if not self.tree_decoding:
            target_sampler.should_modify_greedy_probs_inplace = True

        proposer = self.proposer_worker
        proposer.set_include_gpu_probs_tensor()
        proposer.set_should_modify_greedy_probs_inplace()
399

400
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return ``(num_gpu_blocks, num_cpu_blocks)`` for both KV caches.

        The scorer model (typically the larger one) is profiled. Its GPU
        block budget is then shared between the scorer and proposer KV caches
        so that both end up with the same number of blocks. The CPU block
        count from the scorer is returned unchanged.
        """
        scorer = self.scorer_worker
        gpu_blocks, cpu_blocks = scorer.determine_num_available_blocks()

        scorer_block_bytes = scorer.get_cache_block_size_bytes()
        proposer_block_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        shared_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_block_bytes, proposer_block_bytes, gpu_blocks)
        return shared_gpu_blocks, cpu_blocks

421
422
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engines of the scorer, then the proposer,
        with identical block counts.
        """
        block_counts = dict(num_gpu_blocks=num_gpu_blocks,
                            num_cpu_blocks=num_cpu_blocks)
        for worker in (self.scorer_worker, self.proposer_worker):
            worker.initialize_cache(**block_counts)
429

430
431
432
    def get_model(self) -> nn.Module:
        """Return the target (scorer) model."""
        scorer = self.scorer_worker
        return scorer.get_model()

433
434
    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        On non-driver ranks this runs one iteration of the non-driver loop
        and returns an empty list. On the driver rank:
        * a ``None`` request broadcasts an empty dict (which stops the
          non-driver loops) and returns an empty list;
        * otherwise the per-step control flags are broadcast, then either a
          non-speculative step or a speculative decoding step is run.

        Args:
            execute_model_req: The batch to execute, or ``None`` to signal
                that there is no more work for now.

        Returns:
            The sampler outputs for this step (empty on non-driver ranks and
            for a ``None`` request).
        """
        if self.rank != self._driver_rank:
            self._run_non_driver_rank()
            return []

        if execute_model_req is None:
            # This signals that there's no more requests to process for now.
            # All workers are running infinite loop with broadcast_tensor_dict,
            # and it stops the loop when the driver broadcasts an empty input.
            # Send an empty input to notify all other workers to stop their
            # execution loop. Use the same source rank as the per-step
            # broadcast below so both broadcasts always agree.
            broadcast_tensor_dict({}, src=self._driver_rank)
            return []

        self._track_finished_requests(execute_model_req)
        disable_all_speculation = self._should_disable_all_speculation(
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

        # Validate before the list is iterated below, so a missing list fails
        # with this message rather than an opaque TypeError.
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding requires non-None seq_group_metadata_list")

        all_prompt = True
        atleast_one_prompt = False
        all_zero_spec_tokens = True
        for sgm in execute_model_req.seq_group_metadata_list:
            all_prompt = all_prompt and sgm.is_prompt
            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
            all_zero_spec_tokens = all_zero_spec_tokens and (
                sgm.num_speculative_tokens == 0)

        if all_prompt and execute_model_req.seq_group_metadata_list:
            assert num_lookahead_slots == 0, (
                "Prompt only runs should have num_lookahead_slots equal to 0. "
                "This should never happen, please file a bug at "
                "https://github.com/vllm-project/vllm/issues")

        # Speculative decoding is disabled in the following cases:
        # 1. Prefill phase: Speculative decoding is not
        #    used during the prefill phase.
        # 2. Auto-disable enabled: The running queue size exceeds
        #    the specified threshold.
        # 3. No request: There are no requests in the batch, or
        #    none of the requests in the batch have spec decoding enabled.
        # In any of these cases, the proposer and scorer workers
        # are called normally.
        # We expect `num_speculative_tokens` to be None for prefills.
        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)

        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.
        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.
        #
        # no_spec is used to signal non-driver worker about prefill vs decode
        # stage. This is needed to ensure that order of execution of proposer
        # and scorer is same in both driver and non-driver workers (i.e.,
        # scorer -> proposer for prefill and proposer -> scorer in decode). This
        # order is needed to support models like EAGLE that take scorer states
        # as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
            # When both chunked prefill and speculative decoding are enabled
            # it is possible that the same batch contains both prefill
            # and decodes. If that happens in the scorer we run the batch
            # as one single forward pass. However, in the proposer we
            # run them as 2 different batches - one for prefill and
            # the other for decodes. The variable indicates to the non-driver
            # worker that there are prefills as part of the speculative batch
            # and hence it needs to run an extra prefill forward pass.
            run_spec_proposer_for_prefill=atleast_one_prompt,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
                                                   num_lookahead_slots)

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Drive speculative decoding on a non-driver (parallel) worker.

        ``_run_non_driver_rank`` is called repeatedly until it returns a
        falsy value.
        """
        keep_running = True
        while keep_running:
            keep_running = self._run_non_driver_rank()

532
533
    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        """Return True once the running queue reaches the batch-size cutoff.

        Large batches make speculative decoding trade away throughput for
        latency, so speculation is switched off at or above
        ``self.disable_by_batch_size``.
        """
        cutoff = self.disable_by_batch_size
        queue_size = execute_model_req.running_queue_size
        return queue_size >= cutoff
538
539
540
541
542
543
544
545
546
547
548
549
550
551

    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        """Zero ``num_speculative_tokens`` on every group when speculation is
        disabled for the whole batch.

        Once a request's ``num_speculative_tokens`` is 0, spec decode stays
        disabled for that request for the rest of its lifetime.
        """
        if disable_all_speculation:
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            for group_metadata in seq_group_metadata_list:
                group_metadata.num_speculative_tokens = 0
552

553
554
    def _serialize_sampler_output_no_logprobs(
            self, execute_model_req: ExecuteModelRequest,
            sampler_output: SamplerOutput) -> List[SamplerOutput]:
        """Serialize only the sampled token IDs into a new `SamplerOutput`.

        Every logprob-related field of each `CompletionSequenceGroupOutput`
        is filled with placeholders (rank -1, logprob 0.0, empty top-k).
        Prompt logprobs, when requested, are likewise placeholder entries for
        the prompt tokens in the current chunk.

        Args:
            execute_model_req (ExecuteModelRequest): The model request that
                was executed.
            sampler_output (SamplerOutput): The output from the sampler with
                only GPU tensors populated.

        Returns:
            A one-element list holding a `SamplerOutput` whose outputs are
            `CompletionSequenceGroupOutput` objects, one per sequence group,
            in request order. Groups that do not sample get an output with
            no samples.
        """
        group_metadata_list = execute_model_req.seq_group_metadata_list

        wants_prompt_logprobs = []
        for group in group_metadata_list:
            wants_prompt_logprobs.append(
                group.is_prompt
                and group.sampling_params.prompt_logprobs is not None
                and group.sampling_params.prompt_logprobs > 0)

        token_ids = sampler_output.sampled_token_ids
        if any(wants_prompt_logprobs):
            # Prompt-token slots are padded with VLLM_INVALID_TOKEN_ID; keep
            # only rows holding a real sampled token. Subtracting is faster
            # than testing for equality and is non-zero exactly for valid ids.
            valid_rows = torch.where(token_ids - VLLM_INVALID_TOKEN_ID)[0]
            token_ids = token_ids[valid_rows]
        sampled_token_ids_list = token_ids.tolist()

        # NOTE(review): this flattens every sequence of every group, but it
        # is indexed by *group* index below, so it presumably assumes one
        # sequence per group — confirm.
        seq_data_entries = [
            entry for group in group_metadata_list
            for entry in group.seq_data.items()
        ]

        outputs: List[CompletionSequenceGroupOutput] = []
        # Position in sampled_token_ids_list; only groups that sample
        # consume a row, which keeps non-terminal prefill chunks aligned
        # with their own empty output.
        sample_cursor = 0
        for group_idx, (group_meta, wants_logprobs) in enumerate(
                zip(group_metadata_list, wants_prompt_logprobs)):
            seq_id, seq_data = seq_data_entries[group_idx]

            prompt_logprobs = None
            if wants_logprobs:
                computed = seq_data._num_computed_tokens
                # Some of these sequences may belong to non-terminal chunks,
                # which may still have to report logprobs for prompts. The
                # very first prompt token never has a logprob.
                chunk_start = computed or 1
                chunk_end = computed + group_meta.token_chunk_size
                chunk_token_ids = (
                    seq_data.get_prompt_token_ids()[chunk_start:chunk_end])
                prompt_logprobs = [
                    create_logprobs_output(
                        token_id=prompt_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for prompt_token_id in chunk_token_ids
                ]

            if group_meta.do_sample:
                outputs.append(
                    create_sequence_group_output(
                        token_id=sampled_token_ids_list[sample_cursor][0],
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        seq_id=seq_id,
                        topk_token_ids=[],
                        topk_logprobs=[],
                        prompt_logprobs=prompt_logprobs))
                sample_cursor += 1
            else:
                # Chunked prefill: only the last chunk samples a token, but
                # every chunk still needs an output entry.
                outputs.append(
                    CompletionSequenceGroupOutput(
                        samples=[], prompt_logprobs=prompt_logprobs))

        return [SamplerOutput(outputs=outputs)]
642

643
    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a single generation step without any speculation.

        The input is sent to the scorer model and, unless ``skip_proposer``
        is True, to the proposer model as well so that the KV caches of the
        two stay consistent. When ``skip_proposer`` is True the proposer is
        not called, so its KV cache is not updated for these requests and
        they cannot enable spec decode for the rest of their decoding.

        Side effects:
            - Consumes any pending ``self.kvcache_slot_to_be_moved`` (tree
              decoding only) by attaching it to ``execute_model_req``.
            - Creates or updates ``self.previous_hidden_states`` from the
              scorer's hidden states and, in tree decoding,
              ``self.previous_logits`` from the scorer's logits.

        Returns:
            ``[sampler_output]`` from the scorer, or the serialized
            no-logprobs outputs when ``self._disable_logprobs`` is set.
        """
        # Hand KV-cache moves left over from the previous tree-decoding step
        # to the scorer, and clear them so they are applied exactly once.
        if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
            execute_model_req.kvcache_slot_to_be_moved = \
                self.kvcache_slot_to_be_moved
            self.kvcache_slot_to_be_moved = None

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Store hidden states from target model execution, BxD.
        hidden_states = sampler_output.hidden_states
        if hidden_states is not None:
            # Only decodes and prefill terminal chunks need a hidden state.
            seq_group_meta_with_hidden = [
                sg for sg in execute_model_req.seq_group_metadata_list
                if sg.do_sample
            ]
            if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
                # Drop hidden_states with no prediction (eg non-terminal
                # chunks). Those slots hold VLLM_INVALID_TOKEN_ID, so the
                # subtraction is zero there and torch.where skips them.
                hidden_states = hidden_states[
                    torch.where(sampler_output.sampled_token_ids -
                                VLLM_INVALID_TOKEN_ID)[0]]
            if seq_group_meta_with_hidden:
                if self.previous_hidden_states is None:
                    self.previous_hidden_states = HiddenStates(
                        hidden_states, seq_group_meta_with_hidden)
                else:
                    self.previous_hidden_states.update(
                        hidden_states, seq_group_meta_with_hidden)

            # Store logits from target model execution.
            # NOTE(review): logits are only stored when hidden states are
            # returned too; presumably tree decoding always returns hidden
            # states — confirm.
            if self.tree_decoding and sampler_output.logits is not None:
                logits = sampler_output.logits
                if self.previous_logits is None:
                    self.previous_logits = Logits(
                        logits, execute_model_req.seq_group_metadata_list)
                else:
                    self.previous_logits.update(
                        logits, execute_model_req.seq_group_metadata_list)

        if not skip_proposer:
            # We prepare the prefill hidden states here so that there is no
            # additional complexity in worker for spec_decode vs non_spec_decode
            # flow and execute_model doesn't need additional modifications.
            execute_model_req.previous_hidden_states = \
                prepare_prefill_hidden_states(
                    sampler_output.prefill_hidden_states)

            self.proposer_worker.execute_model(execute_model_req)

        if self._disable_logprobs:
            sampler_output_to_return = \
                self._serialize_sampler_output_no_logprobs(
                    execute_model_req=execute_model_req,
                    sampler_output=sampler_output)
        else:
            sampler_output_to_return = [sampler_output]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.sampled_token_probs = None
        sampler_output.sampled_token_ids = None
        sampler_output.logprobs = None
        return sampler_output_to_return

    def _run_non_driver_rank(self) -> bool:
        """Run proposer and verifier model in non-driver workers.

        This is used for both speculation cases (num_lookahead_slots>0) and
        non-speculation cases (e.g. prefill). The step configuration is
        received from the driver rank via ``broadcast_tensor_dict``, and the
        calls made here mirror the driver's so collective operations line up.

        Returns True if there are remaining sequences to process (i.e. the
        driver broadcast a non-empty step description), False otherwise.
        """
        assert self.rank != self._driver_rank

        data = broadcast_tensor_dict(src=self._driver_rank)
        if not data:
            return False
        num_lookahead_slots = data["num_lookahead_slots"]
        no_spec = data["no_spec"]

        # In case of prefill, scorer_worker has to be run before proposer so
        # that the hidden states can be propagated to proposer when needed.
        if no_spec:
            self.scorer_worker.execute_model()

        if not data["disable_all_speculation"]:
            if isinstance(self.proposer_worker, NonLLMProposerWorkerBase):
                # Non-LLM proposers are driven through sampler_output() and
                # only on speculative steps.
                if not no_spec:
                    self.proposer_worker.sampler_output(
                        None, num_lookahead_slots, None)
            else:
                # Even if num_lookahead_slots is zero, we want to run the
                # proposer model as it may have KV.
                #
                # We run the proposer once per lookahead slot. In the future
                # we should delegate how many times it runs to the proposer.
                for _ in range(max(num_lookahead_slots, 1)):
                    self.proposer_worker.execute_model()

        if not no_spec:
            self.scorer_worker.execute_model()
            if data["run_spec_proposer_for_prefill"]:
                self.proposer_worker.execute_model()

        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self, execute_model_req: ExecuteModelRequest,
            num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        When `enable_chunked_prefill` is set, scorer will batch decodes and
        prefills, while proposer will sync its KV-cache by running an extra
        forward on prefills.

        In tree decoding, the tree attention masks/positions produced by the
        proposer are forwarded to the scorer, and after verification the
        KV-cache moves for the accepted path are recorded via
        ``move_caches``.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        # With prefill chunking, expect requests to have prompts first
        # so that backend gets prefill|decode.
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Hand the state carried over from the previous target-model step
        # (last hidden states, last logits, pending KV-cache slot moves) to
        # this request and clear it on self so it is consumed only once.
        execute_model_req.previous_hidden_states = self.previous_hidden_states
        self.previous_hidden_states = None
        execute_model_req.previous_logits = self.previous_logits
        self.previous_logits = None
        execute_model_req.kvcache_slot_to_be_moved = \
            self.kvcache_slot_to_be_moved
        self.kvcache_slot_to_be_moved = None

        with Timer() as proposal_timer:
            # Generate proposals using draft worker.
            proposals = self.proposer_worker.get_spec_proposals(
                execute_model_req, self._seq_with_bonus_token_in_last_step)

        if not self._allow_zero_draft_token_step and proposals.no_proposals:
            #TODO: Fix it #5814
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

        # Pass tree attention masks and positions to the target model.
        if self.tree_decoding:
            execute_model_req.tree_attn_masks = proposals.tree_attn_masks
            execute_model_req.tree_position_ids = proposals.tree_position_ids

        execute_model_req.previous_hidden_states = None

        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
                proposals,
            )

        _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len(
            execute_model_req.seq_group_metadata_list, proposals.proposal_lens)
        # With prefill chunking enabled, `non_spec_seqs` contains prefills too:
        # discard decodes that have already been processed by proposer.
        non_spec_indices = [
            idx for idx in non_spec_indices
            if execute_model_req.seq_group_metadata_list[idx].is_prompt
        ]
        if non_spec_indices:
            all_hidden_states = proposal_scores.hidden_states
            if all_hidden_states is not None:
                prefill_hidden_states = all_hidden_states[non_spec_indices]
                execute_model_req.previous_hidden_states = \
                    prepare_prefill_hidden_states(prefill_hidden_states)
            # Sync proposer KV cache for prefills.
            prefill_req = execute_model_req.clone(non_spec_seqs)
            # TODO avoid sampling here?
            self.proposer_worker.execute_model(prefill_req)

        with Timer() as verification_timer:
            (accepted_token_ids, target_logprobs, select_indices_list,
             accept_lengths) = self._verify_tokens(
                 execute_model_req.seq_group_metadata_list, proposal_scores,
                 proposals, execute_model_req.num_lookahead_slots)

            # Move the KV caches of the selected (accepted) tree tokens to
            # their final positions.
            if self.tree_decoding:
                self.move_caches(execute_model_req, select_indices_list,
                                 accept_lengths)

        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
                       scoring_timer.elapsed_time_ms,
                       verification_timer.elapsed_time_ms)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            prompt_logprobs=proposal_scores.prompt_logprobs
            if not self._disable_logprobs else None,
            k=execute_model_req.num_lookahead_slots,
            stage_times=stage_times)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
               Optional[List[int]]]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer
        models.

        As a side effect, the target model's hidden states (and, for
        tree-style proposals, logits) at the last accepted token are stored
        on ``self.previous_hidden_states`` / ``self.previous_logits`` for the
        next proposal step.

        Returns:
            A tuple ``(accepted_token_ids, logprobs, select_indices_list,
            accept_lengths)``:
            - accepted_token_ids: accepted token ids, one row per entry of
              ``seq_group_metadata_list`` in its original order, padded
              with -1.
            - logprobs: the scoring model's logprobs
              (``proposal_scores.logprobs``).
            - select_indices_list: when ``proposals.cart_candidates`` is set
              (tree-style proposals), one CPU tensor per sequence with the
              tree positions along the accepted path; empty otherwise.
            - accept_lengths: for tree-style proposals, the per-sequence
              accept lengths reported by the sampler; None otherwise.

        Raises:
            RuntimeError: for tree-style proposals when the scorer did not
                return hidden states or logits.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
            seq_group_metadata_list, proposal_lens_list)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, including bonus tokens.
        if non_spec_indices:
            proposal_verifier_probs = proposal_scores.probs[spec_indices]
        else:
            proposal_verifier_probs = proposal_scores.probs

        if self.tree_decoding:
            # Gather the target probabilities along each candidate path.
            proposal_verifier_probs = proposal_verifier_probs[
                :, proposals.retrieve_indices]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[:, -1:]
        if non_spec_indices:
            bonus_token_ids = bonus_token_ids[spec_indices, :]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs
        if proposal_probs is not None and non_spec_indices:
            proposal_probs = proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids
        if non_spec_indices:
            proposal_token_ids = proposal_token_ids[spec_indices]

        # Get tree buffers (None unless the proposer is tree-style).
        cart_candidates = proposals.cart_candidates
        if cart_candidates is not None and non_spec_indices:
            cart_candidates = cart_candidates[spec_indices]

        # Sampler arguments
        sampler_extra_kwargs: Dict[str, Any] = {}
        if self.generators and isinstance(self.spec_decode_sampler,
                                          SpecDecodeStochasticBaseSampler):
            sampler_extra_kwargs["seeded_seqs"] = {
                idx: self.generators[sgm.request_id]
                for idx, sgm in enumerate(seq_group_metadata_list)
                if sgm.sampling_params.seed is not None
            }

        if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
            sampler_extra_kwargs["cart_candidates"] = cart_candidates
            # These empty lists are presumably filled in place by the
            # sampler; they are read back in the tree branch below.
            sampler_extra_kwargs["best_candidates"] = []
            sampler_extra_kwargs["accept_lengths"] = []
            sampler_extra_kwargs["first_step_flags"] = [
                bool(next(iter(sgm.seq_data.values())).get_first_step_flag())
                for sgm in seq_group_metadata_list
            ]

        accepted_token_ids = self.spec_decode_sampler(
            target_with_bonus_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
            **sampler_extra_kwargs,
        )

        # Append output tokens from non-speculative sequences to the accepted
        # token ids tensor, padded with -1 to the same number of columns.
        # NOTE(review): tree decoding pads to max_proposal_len instead of
        # max_proposal_len + 1; presumably this matches the width the tree
        # sampler returns — confirm.
        num_cols = (max_proposal_len
                    if self.tree_decoding else max_proposal_len + 1)
        # expand() returns a view, so clone before writing into it.
        non_spec_token_ids = non_spec_token_ids.expand(-1, num_cols).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs
        # Rearrange so that results are in the order of the original seq group
        # metadata.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        # B x K+1 x D
        hidden_states = proposal_scores.hidden_states

        accept_lengths = None
        select_indices_list: List[torch.Tensor] = []

        if cart_candidates is None:
            if hidden_states is not None:
                # Only get terminal hidden states for next step
                terminal_metadata = [
                    sg for sg in seq_group_metadata_list if sg.do_sample
                ]

                # Contract hidden states based on accepted tokens
                hs_size = hidden_states.shape[-1]
                accepted_index = accepted_token_ids + 1  # Convert -1 to 0
                accepted_index = accepted_index.count_nonzero(
                    dim=1).add_(-1)  # b
                # Drop non-terminal prefill chunks hidden states.
                hidden_states = hidden_states[accepted_index !=
                                              VLLM_INVALID_TOKEN_ID]
                accepted_index = accepted_index[accepted_index !=
                                                VLLM_INVALID_TOKEN_ID]
                assert len(accepted_index) == hidden_states.shape[0] == len(
                    terminal_metadata)
                index = accepted_index[:, None, None].expand(
                    -1, 1, hs_size)  # b x 1 x d
                second_last_token_hidden_states = hidden_states[:, -2]  # b x d
                hidden_states = hidden_states.gather(1,
                                                     index).squeeze(1)  # b x d
                # Store hidden states from target model for subsequent decode
                # step.
                self.previous_hidden_states = HiddenStates(
                    hidden_states, terminal_metadata,
                    second_last_token_hidden_states)
        else:
            if hidden_states is None or proposal_scores.logits is None:
                raise RuntimeError(
                    "Tree-style verification requires the scorer to return "
                    "both hidden states and logits.")

            retrieve_indices = proposals.retrieve_indices
            batch_size = len(seq_group_metadata_list)
            best_candidates = sampler_extra_kwargs["best_candidates"]
            accept_lengths = sampler_extra_kwargs["accept_lengths"]

            # Group the scorer's hidden states per sequence.
            hs_size = hidden_states.shape[-1]
            hidden_states = hidden_states.view(batch_size, -1, hs_size)

            # Store logits from target model for subsequent proposal.
            logits = proposal_scores.logits
            logits = logits.view(batch_size, -1, logits.shape[-1])
            # [batch_size, retrieve_size, max_depth, vocab_size]
            logits = logits[:, retrieve_indices]

            # The per-sequence path indices are taken from a CPU copy.
            retrieve_indices = retrieve_indices.cpu()

            previous_logits_list = []
            previous_hidden_state_list = []
            for i in range(batch_size):
                best = best_candidates[i]
                accept_len = accept_lengths[i]
                # Logits at the last accepted position of the best path.
                previous_logits_list.append(
                    logits[i, best, accept_len].unsqueeze(0))
                # Tree positions along the best path, up to and including the
                # last accepted token.
                select_indices = retrieve_indices[best, :accept_len + 1]
                select_indices_list.append(select_indices)
                # Hidden state of the last accepted token.
                previous_hidden_state_list.append(
                    hidden_states[i, select_indices[-1]].unsqueeze(0))

            self.previous_logits = Logits(
                torch.cat(previous_logits_list, dim=0),
                seq_group_metadata_list)

            # [batch_size, hidden_size]
            self.previous_hidden_states = HiddenStates(
                torch.cat(previous_hidden_state_list, dim=0),
                seq_group_metadata_list)

        return accepted_token_ids, logprobs, select_indices_list, accept_lengths

    def move_caches(self, execute_model_req: ExecuteModelRequest,
                    select_indices_list: List[torch.Tensor],
                    accept_lengths: List[int]) -> None:
        """Record the KV-cache slot moves for the accepted tree tokens.

        Given the selected output tokens and accept lengths, compute the
        (source slot, destination slot) pairs that move the KV caches of the
        accepted tokens to consecutive positions right after each sequence,
        and store them in ``self.kvcache_slot_to_be_moved`` as a
        ``[num_moves, 2]`` long tensor on ``self.device``. The attribute is
        left untouched when there is nothing to move.

        The source position of tree node ``j`` is taken to be
        ``seq_len + j`` (from the index arithmetic below; assumption about
        the scorer's cache layout — confirm). The destination of the k-th
        accepted token is ``seq_len + k``.

        This also clears the first-step flag of every sequence in the batch,
        since moving the cache is the last stage of a tree-decoding step.

        Args:
            execute_model_req: the request that was just scored.
            select_indices_list: per sequence, the tree positions along the
                accepted path; its first element is skipped.
            accept_lengths: per sequence, the number of accepted tokens.

        Raises:
            RuntimeError: if the scorer's model input exposes no block
                tables.
        """
        seq_lens = []
        for sg in execute_model_req.seq_group_metadata_list:
            token_chunk_size = sg.token_chunk_size
            for seq_data in sg.seq_data.values():
                seq_len = seq_data.get_len()
                context_len = seq_len - 1
                seq_len = min(seq_len, context_len + token_chunk_size)

                # The first step of tree-style decoding needs to ignore the
                # first generated token.
                if seq_data.get_first_step_flag():
                    seq_len -= 1

                # Moving the cache is the last stage of tree decoding, so
                # clear the first-step flag.
                seq_data.set_first_step_flag(False)
                seq_lens.append(seq_len)

        model_input = self.scorer._scorer_worker.model_input
        attn_metadata = getattr(model_input, 'attn_metadata', None)
        block_tables = getattr(attn_metadata, 'block_tables_list', None)
        if block_tables is None:
            raise RuntimeError("Can not get block_tables from model_input.")

        batch_size = len(select_indices_list)
        if batch_size == 0:
            # No verified sequences: nothing to move (and avoids dividing by
            # zero below).
            return

        cache_engine = self.scorer._scorer_worker.cache_engines[
            execute_model_req.virtual_engine]
        block_size = cache_engine.block_size
        # NOTE(review): assumes block_tables holds the same number of entries
        # for every sequence, laid out back to back — confirm.
        block_table_stride = len(block_tables) // batch_size

        select_indices_slot_mapping: List[int] = []
        target_slot_mapping: List[int] = []
        for i in range(batch_size):
            accept_length = accept_lengths[i]
            if not accept_length > 0:
                continue
            block_table_offset = i * block_table_stride

            # Current positions of the accepted tree tokens.
            select_indices = (select_indices_list[i][1:] +
                              seq_lens[i]).tolist()
            self.compute_slot_mapping(select_indices_slot_mapping,
                                      block_table_offset, select_indices,
                                      block_size, block_tables)

            # Destination positions: seq_len + 1 ... seq_len + accept_length.
            target_indices = (torch.arange(1, accept_length + 1) +
                              seq_lens[i]).tolist()
            self.compute_slot_mapping(target_slot_mapping, block_table_offset,
                                      target_indices, block_size,
                                      block_tables)

        if select_indices_slot_mapping:
            select_indices_slot_tensor = torch.tensor(
                select_indices_slot_mapping,
                dtype=torch.long,
                device=self.device).view(-1, 1)
            target_slot_mapping_tensor = torch.tensor(
                target_slot_mapping, dtype=torch.long,
                device=self.device).view(-1, 1)
            # [num_moves, 2]: column 0 is the source slot, column 1 the
            # destination slot.
            self.kvcache_slot_to_be_moved = torch.cat(
                [select_indices_slot_tensor, target_slot_mapping_tensor],
                dim=-1)

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
1108
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
1109
1110
        prompt_logprobs: Optional[
            torch.Tensor],  # shape: [nprompt_tokens, vocab_size]
1111
        k: int,
1112
        stage_times: Tuple[float, float, float],
1113
1114
1115
1116
1117
1118
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs.
        """
1119
1120
        batch_size, num_steps = accepted_token_ids.shape
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
        if self._disable_logprobs:
            # We are skipping the logprobs. Hence don't serialize the
            # logprobs related tensors from the GPU. Instead create
            # empty/dummy lists.
            (accepted_token_id_ranks_by_step,
            accepted_token_id_logprobs_by_step,
            topk_logprobs_by_step, topk_indices_by_step) =\
            self._create_dummy_logprob_lists(
                batch_size, num_steps,
                self.scorer_worker.model_config.max_logprobs)
        else:
            # Organize input tensors by step instead of by sequence.
            target_logprobs_by_step = target_logprobs.transpose(0, 1)
            # Serialize all tensors into Python lists.
            (accepted_token_id_ranks_by_step,
            accepted_token_id_logprobs_by_step,
            topk_logprobs_by_step, topk_indices_by_step) =\
                self._create_logprob_lists_from_tensors(
                    target_logprobs_by_step, accepted_token_ids_by_step,
                    self.scorer_worker.model_config.max_logprobs)
1141
1142
1143

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
1144
1145
1146
        seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
            seq_group_metadata_list)

1147
1148
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

1149
        # Serialize tensor to CPU Python list.
1150
1151
1152
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
1153
        # Non-terminal prefill chunks will end up here as rows with just -1s
1154
1155
        # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
        # terminal chunks will only have one generated token at time 0.
1156
        sampler_output_list: List[SamplerOutput] = []
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222

        # Prefills are not multi-step (return at most 1 token), in order to
        # avoid padding or repetition to fit decodes, we separate them.
        for i, sg in enumerate(seq_group_metadata_list):
            if not sg.is_prompt:
                # Requests are ordered as prefills|decodes=>no more prefills.
                break
            num_logprobs = num_logprobs_per_seq[i]
            seq_kwargs = dict(token_id=-1,
                              token_id_logprob_rank=0,
                              token_id_logprob=-float('inf'),
                              topk_token_ids=[-1] * num_logprobs,
                              topk_logprobs=[-float('inf')] * num_logprobs,
                              seq_id=seq_ids[i])
            # Terminal chunk, has token.
            if sg.do_sample:
                seq_kwargs.update(
                    dict(
                        token_id=accepted_token_ids[i][0].item(),
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            0][i],
                        token_id_logprob=accepted_token_id_logprobs_by_step[0]
                        [i],
                        topk_token_ids=topk_indices_by_step[0][i]
                        [:num_logprobs],
                        # output only so step is 0
                        topk_logprobs=topk_logprobs_by_step[0][i]
                        [:num_logprobs],
                    ))
            needs_plogs = (sg.sampling_params.prompt_logprobs
                           and sg.sampling_params.prompt_logprobs > 0)
            plogs = None
            if prompt_logprobs is not None:
                # Even non-terminal prompt chunks can have logprobs here.
                plogs = prompt_logprobs[i]
            elif needs_plogs:
                # Prompt logprobs are requested but `_disable_logprobs` is set.
                seq_data = next(iter(sg.seq_data.values()))
                # Get only the tokens in this chunk!
                prompt_token_ids = seq_data.get_prompt_token_ids()
                prompt_token_ids = prompt_token_ids[
                    seq_data.
                    _num_computed_tokens:seq_data._num_computed_tokens +
                    sg.token_chunk_size]

                is_first_chunk = seq_data._num_computed_tokens == 0
                # There's no prob generated for the first token in a sequence.
                if is_first_chunk:
                    prompt_token_ids = prompt_token_ids[1:]
                plogs = [
                    create_logprobs_output(
                        token_id=p_token_id,
                        token_id_logprob_rank=-1,
                        token_id_logprob=0.0,
                        topk_token_ids=[],
                        topk_logprobs=[],
                    ) for p_token_id in prompt_token_ids
                ]
            seq_kwargs.update(dict(prompt_logprobs=plogs))

            sampler_output_list.append(
                SamplerOutput(
                    outputs=[create_sequence_group_output(
                        **seq_kwargs)]))  # type: ignore

        # Decodes, create one SamplerOutput per-step (at most K+1).
1223
        for step_index in range(num_steps):
1224
1225
1226
1227
            if all(token_id == -1 for sg, token_id in zip(
                    seq_group_metadata_list,
                    accepted_token_ids_by_step[step_index])
                   if not sg.is_prompt):
1228
1229
                break

1230
            step_output_token_ids: List[CompletionSequenceGroupOutput] = []
1231
            for sequence_index in range(batch_size):
1232
1233
1234
1235
1236
                seq_meta = seq_group_metadata_list[sequence_index]
                # Prompts already processed above.
                if seq_meta.is_prompt:
                    continue

1237
1238
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]
1239
                step_output_token_ids.append(
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
1252
1253
1254
1255
                    ))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

1256
1257
1258
1259
1260
        # Populate the data structures needed to keep track of sequences with
        # bonus tokens.
        self._track_sequences_with_bonus_tokens(seq_ids,
                                                request_ids_seq_ids_mapping,
                                                accepted_token_ids_by_step)
1261
1262
        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
1263
        if maybe_rejsample_metrics is not None and sampler_output_list:
1264
1265
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics
1266
1267
1268
1269
1270

            # Log time spent in each stage periodically.
            # This is periodic because the rejection sampler emits metrics
            # periodically.
            self._maybe_log_stage_times(*stage_times)
1271
1272
        # First `n_prefills` entries will contain prefills SamplerOutput when
        # chunked prefill is enabled, the rest is decodes in multi-step format.
1273
1274
        return sampler_output_list

1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
                               scoring_time_ms: float,
                               verification_time_ms: float) -> None:
        """Log the speculative stage times. If stat logging is disabled, do
        nothing.
        """
        if self._disable_log_stats:
            return

        logger.info(
            "SpecDecodeWorker stage times: "
            "average_time_per_proposal_tok_ms=%.02f "
            "scoring_time_ms=%.02f verification_time_ms=%.02f",
            average_time_per_proposal_tok_ms, scoring_time_ms,
            verification_time_ms)

1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
    def _create_dummy_logprob_lists(
        self,
        batch_size: int,
        num_steps: int,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Creates and returns four dummy lists representing token probabilities 
        and their ranks.

        This method initializes and returns:
            - The ranks of the accepted tokens, shaped (num_steps, batch_size)
            - The log probabilities of the accepted tokens,
              shaped (num_steps, batch_size)
            - The log probabilities of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)
            - The token IDs of the top k tokens,
              shaped (num_steps, batch_size, num_top_k)

        Args:
            batch_size (int): The size of the batch.
            num_steps (int): The number of steps in the sequence.
            num_top_k (int): The number of top-k token log probabilities to
            return.
        
        Returns:
            A tuple containing four dummy lists as described above.
        """
        accepted_token_id_ranks_by_step = [[-1] * batch_size
                                           for _ in range(num_steps)]
        accepted_token_id_logprobs_by_step = [[0.0] * batch_size
                                              for _ in range(num_steps)]
        topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        topk_indices_by_step: List[List[List[Optional[int]]]] = [[
            [None] * num_top_k for _ in range(batch_size)
        ] for _ in range(num_steps)]
        return (accepted_token_id_ranks_by_step,
                accepted_token_id_logprobs_by_step, topk_logprobs_by_step,
                topk_indices_by_step)

    def _create_logprob_lists_from_tensors(
        self,
        target_logprobs_by_step: torch.Tensor,
        accepted_token_ids_by_step: torch.Tensor,
        num_top_k: int,
    ) -> Tuple[List[List[int]], List[List[float]],
               List[List[List[Optional[float]]]],
               List[List[List[Optional[int]]]]]:
        """
        Convert target-model logprob tensors into plain Python lists.

        The returned lists hold, in order:
            - the ranks of the accepted tokens, shaped (num_steps, batch_size)
            - the logprobs of the accepted tokens,
              shaped (num_steps, batch_size)
            - the top-k logprobs, shaped (num_steps, batch_size, num_top_k)
            - the top-k token IDs, shaped (num_steps, batch_size, num_top_k)

        Args:
            target_logprobs_by_step (torch.Tensor): Target-model logprobs,
            shaped (num_steps, batch_size, vocab_size).
            accepted_token_ids_by_step (torch.Tensor): Accepted token IDs,
            shaped (num_steps, batch_size).
            num_top_k (int): How many top-k logprobs to keep per position.

        Returns:
            A tuple of the four lists described above.
        """
        # Rank and logprob of each accepted token under the target model.
        accepted_ranks, accepted_logprobs = get_sampled_token_logprobs(
            logprob_tensor=target_logprobs_by_step,
            sampled_token_ids=accepted_token_ids_by_step,
        )
        # The k highest logprobs along the last dim; the accepted token may
        # or may not be among them.
        best_logprobs, best_token_ids = target_logprobs_by_step.topk(
            k=num_top_k,
            dim=-1,
        )
        # Materialise everything on the host as nested Python lists.
        return (
            accepted_ranks.tolist(),
            accepted_logprobs.tolist(),
            best_logprobs.tolist(),
            best_token_ids.tolist(),
        )

1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
    def _track_finished_requests(self, execute_model_req: ExecuteModelRequest):
        """
        Drop bookkeeping for every request listed as finished.

        For each ID in ``execute_model_req.finished_requests_ids``, the
        request's sequence IDs are removed from the bonus-token set, and the
        request's entry is deleted from the request-to-sequence mapping.
        """
        # NOTE(review): lookups of IDs that were never tracked rely on the
        # mapping being a defaultdict (presumably defaultdict(set)); confirm.
        seq_ids_by_request = self._request_id_seq_id_mapping
        bonus_seq_ids = self._seq_with_bonus_token_in_last_step
        for request_id in execute_model_req.finished_requests_ids:
            bonus_seq_ids.difference_update(seq_ids_by_request[request_id])
            del seq_ids_by_request[request_id]

    def _track_sequences_with_bonus_tokens(
            self, seq_ids: List[int],
            request_ids_seq_ids_mapping: Dict[str, Set[int]],
            accepted_token_ids_by_step: List[List[int]]):
        """
        Updates the internal data structures which keep track of sequences
        which have been assigned bonus tokens in their last forward pass.
        """
        for seq_index, seq_id in enumerate(seq_ids):
            last_token_id = accepted_token_ids_by_step[-1][seq_index]
            if last_token_id == -1:
                self._seq_with_bonus_token_in_last_step.discard(seq_id)
            else:
                self._seq_with_bonus_token_in_last_step.add(seq_id)
        for request_id, sequences in request_ids_seq_ids_mapping.items():
            self._request_id_seq_id_mapping[request_id].update(sequences)

1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, taken from the scorer worker."""
        scorer = self.scorer_worker
        return scorer.rank

    @property
    def device(self):
        """Device of this worker, taken from the scorer worker."""
        scorer = self.scorer_worker
        return scorer.device

1440
1441
1442
1443
    @property
    def _driver_rank(self) -> int:
        return 0

1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
    def get_cache_block_size_bytes(self):
        """Unsupported: always raises NotImplementedError.

        Block-size queries are used to compose workers inside a
        SpecDecodeWorker. Nesting a SpecDecodeWorker inside another
        SpecDecodeWorker is left undefined for now, though it could be
        implemented later (see https://arxiv.org/abs/2308.04623).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

1454
    def start_profile(self):
        """Start profiling on the scorer worker.

        Profiling starts only when the scorer worker is a ``WorkerBase``
        instance. For any other scorer type this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.start_profile()

    def stop_profile(self):
        """Stop profiling on the scorer worker.

        Profiling stops only when the scorer worker is a ``WorkerBase``
        instance. For any other scorer type this is a no-op.
        """
        if isinstance(self.scorer_worker, WorkerBase):
            self.scorer_worker.stop_profile()

1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model, this function calculates how many blocks
    should be given to the draft and target model.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    Args:
        scorer_cache_block_size_bytes: Bytes per KV-cache block of the
            scorer (target) model.
        proposer_cache_block_size_bytes: Bytes per KV-cache block of the
            proposer (draft) model.
        total_num_gpu_blocks: Blocks the target model could allocate alone.

    Returns:
        floor(total_num_gpu_blocks * scorer / (scorer + proposer)), the number
        of blocks each model should allocate.

    Raises:
        ZeroDivisionError: If both block sizes are zero.
    """
    # Exact integer floor division: float true division can round a quotient
    # lying just below an integer up to that integer, which would
    # over-allocate by one block for very large inputs.
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes //
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495


def prepare_prefill_hidden_states(
        prefill_hidden_states: Optional[torch.Tensor]
) -> Optional[HiddenStates]:
    """Shift prefill hidden states by one position for the proposer.

    Args:
        prefill_hidden_states: Hidden states from the scorer's prefill, or
            None.

    Returns:
        ``HiddenStates`` wrapping the tensor rolled by one along dim 0, or
        None when ``prefill_hidden_states`` is None.
    """
    if prefill_hidden_states is None:
        return None
    # For prefill step in proposer, we run the model for N-1 tokens
    # because Nth token will be processed in the first decode step. For
    # N-1 tokens, the input should be 0:N-1 hidden states which should
    # be concatenated with 1:N token (since output of scorer has to be
    # the input for proposer). Therefore, we shift the hidden states to
    # align n-1th hidden state with nth token.
    return HiddenStates(prefill_hidden_states.roll(shifts=1, dims=0))