spec_decode_worker.py 21.5 KB
Newer Older
1
from functools import cached_property
2
from typing import Dict, List, Optional, Tuple
3
4
5

import torch

6
from vllm.logger import init_logger
7
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
8
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
9
10
11
12
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
13
from vllm.spec_decode.multi_step_worker import MultiStepWorker
14
from vllm.spec_decode.ngram_worker import NGramWorker
15
16
17
from vllm.spec_decode.util import (create_sequence_group_output,
                                   get_all_num_logprobs, get_all_seq_ids,
                                   get_sampled_token_logprobs, nvtx_range,
18
                                   split_batch_by_proposal_len)
19
20
21
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

# Module-level logger for this worker module.
logger = init_logger(__name__)


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model and n-gram prompt-lookup proposals are implemented
        (contributions for more forms are welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add a MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs,
    ) -> "SpecDecodeWorker":
        """Build a SpecDecodeWorker together with its proposer worker.

        If ``draft_worker_kwargs`` carries a positive
        ``ngram_prompt_lookup_max``, an NGramWorker is used as the proposer and
        its window is set from ``ngram_prompt_lookup_min`` /
        ``ngram_prompt_lookup_max``; otherwise a MultiStepWorker is used. Both
        ``ngram_prompt_lookup_*`` keys are consumed here (defaulting to 0) and
        are never forwarded to the proposer worker's constructor.

        Args:
            scorer_worker: Worker that scores proposals with the target model.
            draft_worker_kwargs: Keyword arguments for the proposer worker's
                constructor. The caller's mapping is not modified.

        Returns:
            A new worker wrapping the proposer and scorer workers, using a
            strict-mode RejectionSampler.
        """
        # Copy so the pops below do not mutate the caller's mapping.
        draft_worker_kwargs = dict(draft_worker_kwargs)

        # Always consume both n-gram keys so neither can leak into the
        # proposer constructor's **kwargs (e.g. a lone "..._min" key).
        ngram_prompt_lookup_max = draft_worker_kwargs.pop(
            "ngram_prompt_lookup_max", 0)
        ngram_prompt_lookup_min = draft_worker_kwargs.pop(
            "ngram_prompt_lookup_min", 0)

        if ngram_prompt_lookup_max > 0:
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

        return cls(
            proposer_worker,
            scorer_worker,
            # TODO(cade) disable strict mode for speedup.
            rejection_sampler=RejectionSampler(strict_mode=True),
        )

    def __init__(
        self,
        proposer_worker: WorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. Defaults to an AsyncMetricsCollector
                built around ``rejection_sampler``.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        self.rejection_sampler = rejection_sampler

        self._metrics = AsyncMetricsCollector(
            rejection_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Lazy initialization; assigned in init_device().
        self.scorer: SpeculativeScorer

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.

        Initializes devices, loads both models, initializes GPU tensors of
        the metrics collector and rejection sampler, builds the top-1 batch
        expansion scorer, and configures the samplers to emit GPU tensors.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.rejection_sampler.init_gpu_tensors(self.rank)
        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of rejection sampling.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        # The proposer decides for itself how to honour this (it may have no
        # sampler at all).
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided between the proposer and scorer model KV
        caches such that the number of blocks is equal in both KV caches.

        Returns:
            (num_gpu_blocks, num_cpu_blocks); the CPU block count is the
            scorer's value, unchanged.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.

        Both workers receive the same block counts so their KV caches stay
        aligned block-for-block.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        When ``num_lookahead_slots`` is 0 or the batch is empty, both workers
        are run normally (no speculation, used for prefill). Otherwise a
        speculative step with ``k = num_lookahead_slots`` is executed.
        """
        assert seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        # If no spec tokens, call the proposer and scorer workers normally.
        # Used for prefill.
        if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
            return self._run_no_spec(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=blocks_to_swap_in,
                blocks_to_swap_out=blocks_to_swap_out,
                blocks_to_copy=blocks_to_copy,
            )

        return self._run_speculative_decoding_step(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            k=num_lookahead_slots,
        )

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
    ) -> List[SamplerOutput]:
        """Run a prefill step, without any speculation. The input is sent to the
        proposer and scorer model so that the KV cache is consistent between the
        two.

        Returns:
            A single-element list with the scorer's SamplerOutput, with its
            device tensors cleared.
        """
        # The proposer's output is discarded; it runs only to fill its KV cache.
        self.proposer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )

        sampler_output = self.scorer_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        # NOTE(review): confirm these attribute names match SamplerOutput's
        # tensor fields; assigning a name the class does not define would just
        # add a new attribute and leave the device tensors attached.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Optional[Dict[int, int]],
        blocks_to_swap_out: Optional[Dict[int, int]],
        blocks_to_copy: Optional[Dict[int, List[int]]],
        k: int,
    ) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(
            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
            blocks_to_copy, k)

        # Score the proposals with the target model.
        proposal_scores = self.scorer.score_proposals(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
            k,
            proposals,
        )

        # Decide which proposed tokens are accepted.
        accepted_token_ids, target_logprobs = self._verify_tokens(
            seq_group_metadata_list, proposal_scores, proposals, k)

        return self._create_output_sampler_list(
            seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=k)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Sequences with a non-zero proposal length go through the rejection
        sampler; sequences with zero proposal length keep the target model's
        single sampled token followed by -1 padding.

        Returns a tuple of Tensors, one for the accepted token ids (rows in the
        order of ``seq_group_metadata_list``, ``max_proposal_len + 1`` columns)
        and one for the logprobs according to the scoring model.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.rejection_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor. Only the first column is a real token;
        # the rest is -1 padding.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata: row i of the concatenation belongs at original_indices[i].
        # The clone avoids reading rows that this assignment overwrites.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs. One SamplerOutput is produced per step;
        emission stops at the first step in which every sequence's token is -1.

        Rejection-sampling metrics, when the metrics collector returns any,
        are attached to the first SamplerOutput.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list = []
        for step_index in range(num_steps):
            step_token_ids = accepted_token_ids_by_step[step_index]
            # -1 marks "no token"; stop once an entire step is padding.
            if all(token_id == -1 for token_id in step_token_ids):
                break

            step_output_token_ids = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]

                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=step_token_ids[sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))

            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        # Guard the [0] access: if every step was padding the list is empty.
        if maybe_rejsample_metrics is not None and sampler_output_list:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, taken from the scorer worker."""
        return self.scorer_worker.rank

    @property
    def device(self):
        """Device of this worker, taken from the scorer worker."""
        return self.scorer_worker.device

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError

def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Given total_num_gpu_blocks, the number of GPU blocks that could be
    allocated to the target model alone, calculate how many blocks should be
    given to each of the draft and target models.

    Note that usually the block size, in bytes, of each model is different,
    as it's a function of number of KV/layer, number of heads, and hidden
    dimension size.

    Since the target and draft models allocate the same number of blocks, we
    simply calculate the number of blocks where if allocated by both models,
    the total memory usage from KV cache is no larger than the number of
    blocks allocatable by the target model alone.

    The result is floor(total_num_gpu_blocks * scorer_size /
    (proposer_size + scorer_size)). It is computed with floor division rather
    than float true division: for large integer inputs, float rounding can
    round the quotient up and over-allocate one block.

    Args:
        scorer_cache_block_size_bytes: Size of one target-model cache block.
        proposer_cache_block_size_bytes: Size of one draft-model cache block.
        total_num_gpu_blocks: Blocks that fit if only the target model
            allocated cache.

    Returns:
        The number of blocks each model should allocate.

    Raises:
        ZeroDivisionError: If both block sizes sum to zero.
    """
    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_cache_block_size_bytes //
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks