# spec_decode_worker.py
from functools import cached_property
from typing import Any, Dict, List, Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.util import (create_sequence_group_output,
                                   get_all_num_logprobs, get_all_seq_ids,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

# Module-level logger, named after this module.
logger = init_logger(__name__)


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

    Speculative decoding reduces decoding per-token latency by using a proposal
    method, such as a small draft model, to speculate ahead of a larger LLM. The
    probabilities of the speculative tokens are then determined by the larger
    LLM, after which some verification routine determines which (if any) of the
    speculative tokens are accepted by the larger LLM.

    See https://github.com/vllm-project/vllm/pull/2188 and
    https://github.com/vllm-project/vllm/pull/3103 for more info.

    The current implementation has the following limitations:
    * Only draft-model and n-gram (prompt lookup) proposals are implemented
        (contributions for more forms are welcome!).
    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
        future work.
    * Only lossless rejection sampling is supported. Contributions adding lossy
        verification routines are welcome (e.g. Medusa's typical acceptance).
    * All sequences in a batch must have the same proposal length, or zero. This
        can be improved by having per-sequence speculation in the future.
    * The scoring forward pass is done without an MQA kernel, which is
        suboptimal especially as the batch size, proposal length, and sequence
        lengths grow. Contributions to add MQA scoring are welcome once
        correctness tests pass.
        More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
    """

    @classmethod
    def create_worker(
        cls,
        scorer_worker: WorkerBase,
        draft_worker_kwargs: Dict[str, Any],
        disable_by_batch_size: Optional[int],
    ) -> "SpecDecodeWorker":
        """Build a speculative decoding worker, choosing the proposer.

        If ``ngram_prompt_lookup_max`` is positive, an NGramWorker is used as
        the proposer and the rejection sampler keeps bonus tokens enabled.
        Otherwise a MultiStepWorker (draft model) is used and bonus tokens
        are disabled.

        Args:
            scorer_worker: The worker that scores proposals (target model).
            draft_worker_kwargs: Keyword arguments for the proposer worker.
                Must contain ``ngram_prompt_lookup_max`` and
                ``ngram_prompt_lookup_min``; those two keys are stripped
                before the rest is forwarded to the proposer constructor.
                The caller's dict is not modified.
            disable_by_batch_size: Forwarded to ``__init__``.

        Returns:
            A new instance of ``cls``.

        Raises:
            KeyError: If either n-gram key is missing from
                ``draft_worker_kwargs``.
        """
        # Work on a shallow copy so the caller's dict is not mutated by pop().
        draft_worker_kwargs = dict(draft_worker_kwargs)
        ngram_prompt_lookup_max = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
        ngram_prompt_lookup_min = (
            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))

        disable_bonus_tokens = True
        if ngram_prompt_lookup_max > 0:
            disable_bonus_tokens = False
            proposer_worker = NGramWorker(**draft_worker_kwargs)
            proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
                                                  ngram_prompt_lookup_max)
        else:
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)

        logger.info("Configuring SpecDecodeWorker with proposer=%s",
                    type(proposer_worker))

        # Instantiate ``cls`` (not a hard-coded class name) so subclasses that
        # use this factory get an instance of their own type.
        return cls(
            proposer_worker,
            scorer_worker,
            disable_by_batch_size=disable_by_batch_size,
            rejection_sampler=RejectionSampler(
                disable_bonus_tokens=disable_bonus_tokens, ))

    def __init__(
        self,
        proposer_worker: WorkerBase,
        scorer_worker: WorkerBase,
        rejection_sampler: RejectionSampler,
        metrics_collector: Optional[AsyncMetricsCollector] = None,
        disable_by_batch_size: Optional[int] = None,
    ):
        """
        Create a SpecDecodeWorker.

        Args:
            proposer_worker: A worker that can produce speculative tokens for
                sequences.
            scorer_worker: A worker that produces probabilities of speculative
                tokens according to some base model. Typically a vanilla vLLM
                Worker.
            rejection_sampler: A Torch module used to perform modified rejection
                sampling for speculative decoding.
            metrics_collector: Helper class for collecting metrics; can be set
                for testing purposes. If None, an AsyncMetricsCollector wrapping
                ``rejection_sampler`` is created.
            disable_by_batch_size: If the running queue size is at least this
                value, disable speculative decoding for new incoming requests.
                None (or 0) means speculative decoding is never disabled this
                way.
        """
        self.proposer_worker = proposer_worker
        self.scorer_worker = scorer_worker
        # ``or`` maps both None and 0 to "never disable".
        self.disable_by_batch_size = disable_by_batch_size or float("inf")
        self.rejection_sampler = rejection_sampler

        self._metrics = AsyncMetricsCollector(
            rejection_sampler
        ) if metrics_collector is None else metrics_collector

        self.probs_dtype = self.rejection_sampler.probs_dtype
        self.token_id_dtype = self.rejection_sampler.token_id_dtype

        # Lazy initialization; assigned in init_device().
        self.scorer: SpeculativeScorer

    def init_device(self) -> None:
        """Initialize both scorer and proposer models.

        Initializes both workers' devices, loads their models, sets up GPU
        tensors for the metrics collector and rejection sampler, creates the
        batch-expansion scorer and configures the samplers to emit GPU
        probability tensors.
        """
        # The scorer worker model is initialized first in case the proposer
        # model has a smaller TP degree than the target worker.
        self.scorer_worker.init_device()
        self.proposer_worker.init_device()

        # NOTE(cade): load_model is not part of the WorkerBase interface.
        self.scorer_worker.load_model()
        self.proposer_worker.load_model()

        self._metrics.init_gpu_tensors(self.rank)
        self.rejection_sampler.init_gpu_tensors(self.rank)
        self.scorer = BatchExpansionTop1Scorer(
            scorer_worker=self.scorer_worker,
            device=self.device,
            vocab_size=self._vocab_size)

        self._configure_model_sampler_for_spec_decode()

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
        which significantly reduces overhead of rejection sampling.

        NOTE(cade): This breaks abstraction boundaries pretty badly. The better
        design is to have the "move to CPU and serialize" sampling decision be
        done outside of the model/sampler; this way the "last-mile" worker
        object which interfaces with the scheduler can serialize and incur the
        performance hit as necessary. This allows us to run the worker several
        iterations in a row without incurring the "move to CPU and serialize"
        performance penalty.

        Since this requires a large change to vLLM, we defer it to later and
        temporarily accept this broken abstraction boundary.

        NOTE(cade): This will require a special check if the proposer worker
        does not have a sampler (e.g. ngram speculation).
        """
        (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
         ) = True
        self.proposer_worker.set_include_gpu_probs_tensor()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of cache blocks to use.

        This is done by profiling the scorer model (which is typically the
        larger of the two). Then the total memory which would be used by the
        scorer cache is divided evenly between the proposer and scorer model KV,
        such that the number of blocks is equal in both KV caches.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``; the CPU block count is the
            scorer's profiled value, passed through unchanged.
        """
        num_gpu_blocks, num_cpu_blocks = (
            self.scorer_worker.determine_num_available_blocks())

        scorer_cache_block_size_bytes = (
            self.scorer_worker.get_cache_block_size_bytes())
        proposer_cache_block_size_bytes = (
            self.proposer_worker.get_cache_block_size_bytes())

        new_num_gpu_blocks = split_num_cache_blocks_evenly(
            scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
            num_gpu_blocks)
        return new_num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the cache engine of the scorer and proposer workers.

        Both workers receive the same block counts.
        """
        self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                            num_cpu_blocks=num_cpu_blocks)
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @torch.inference_mode()
    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.

        Falls back to a plain (non-speculative) step when the request has no
        lookahead slots or an empty batch. When the running queue size reaches
        ``disable_by_batch_size``, speculation is permanently disabled for the
        sequence groups in this batch.
        """
        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")

        # When the batch size is too large, disable speculative decoding
        # to stop trading off throughput for latency.
        disable_all = (execute_model_req.running_queue_size >=
                       self.disable_by_batch_size)
        if disable_all:
            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
                # Once num_speculative_tokens is set to 0, the spec decode
                # of this request will be disabled forever.
                # TODO(comaniac): We currently store spec decoding specific
                # state in the global data structure, but we should maintain
                # this state within spec decode worker.
                seq_group_metadata.num_speculative_tokens = 0

        # If no spec tokens, call the proposer and scorer workers normally.
        # This happens for prefill, or when the spec decode is disabled
        # for this batch.
        if execute_model_req.num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list) == 0:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all)

        return self._run_speculative_decoding_step(execute_model_req)

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                     skip_proposer: bool) -> List[SamplerOutput]:
        """Run a step without any speculation (e.g. prefill).

        The input is sent to both the proposer and scorer models so that their
        KV caches stay consistent. When skip_proposer is True, the proposer
        model is not called, so its KV cache is not updated for these
        requests; they therefore cannot use speculative decoding for the rest
        of their decoding.

        Returns:
            A single-element list holding the scorer's SamplerOutput, with its
            device tensors cleared.
        """
        if not skip_proposer:
            self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
        sampler_output = sampler_output[0]

        # Clear device tensors from sampler output. This reduces communication
        # overhead when the engine runs in a different process than the workers.
        sampler_output.probs = None
        sampler_output.sampled_tokens = None
        sampler_output.logprobs = None
        return [sampler_output]

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
        sequence, then scores each speculative token using the scoring worker.

        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

        proposal_scores = self.scorer.score_proposals(
            execute_model_req,
            proposals,
        )

        accepted_token_ids, target_logprobs = self._verify_tokens(
            execute_model_req.seq_group_metadata_list, proposal_scores,
            proposals, execute_model_req.num_lookahead_slots)

        return self._create_output_sampler_list(
            execute_model_req.seq_group_metadata_list,
            accepted_token_ids,
            target_logprobs=target_logprobs,
            k=execute_model_req.num_lookahead_slots)

    @nvtx_range("spec_decode_worker._verify_tokens")
    def _verify_tokens(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        proposal_scores: SpeculativeScores,
        proposals: SpeculativeProposals,
        max_proposal_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Determine which speculative tokens are accepted using the
        probabilities of each token according to the proposer and scorer models.

        Returns a tuple of Tensors, one for the accepted token ids and one for
        the logprobs according to the scoring model. Rows of the accepted token
        ids are in the order of ``seq_group_metadata_list``; non-speculative
        rows hold one sampled token followed by ``-1`` padding.
        """
        proposal_lens_list = proposals.proposal_lens.tolist()

        # vLLM currently only supports proposal lens equal to zero or the batch
        # proposal len. This adds some complexity (splitting the batch into spec
        # and non spec sequences) and should be removed in the future. It can be
        # done by supporting per-sequence proposal lens.
        _, spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=False)
        _, non_spec_indices = split_batch_by_proposal_len(
            seq_group_metadata_list,
            proposal_lens_list,
            select_proposal_len_zero=True)
        original_indices = spec_indices + non_spec_indices

        # Get probabilities of target model, excluding bonus token.
        proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]

        # Get non-speculative sampled tokens from target model.
        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]

        # Get bonus tokens from target model.
        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]

        # Get probabilities according to proposal method.
        proposal_probs = proposals.proposal_probs[spec_indices]

        # Get proposed tokens.
        proposal_token_ids = proposals.proposal_token_ids[spec_indices]

        accepted_token_ids = self.rejection_sampler(
            target_probs=proposal_verifier_probs,
            bonus_token_ids=bonus_token_ids,
            draft_probs=proposal_probs,
            draft_token_ids=proposal_token_ids,
        )

        # Append output tokens from non-speculative sequences to
        # the accepted token ids tensor. Only column 0 is a real token; the
        # remaining max_proposal_len columns are padded with -1.
        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
                                                       1).clone()
        non_spec_token_ids[:, 1:] = -1
        accepted_token_ids = torch.cat(
            [accepted_token_ids, non_spec_token_ids])
        # NOTE(review): logprobs are returned as-is, presumably already in the
        # original batch order produced by the scorer — TODO confirm.
        logprobs = proposal_scores.logprobs

        # Rearrange so that results are in the order of the original seq group
        # metadata. Row i of the concatenated tensor (spec rows then non-spec
        # rows) is scattered to position original_indices[i]; the clone keeps
        # the source intact while the destination is overwritten.
        accepted_token_ids[original_indices] = accepted_token_ids.clone()

        return accepted_token_ids, logprobs

    def _create_output_sampler_list(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
        target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
        k: int,
    ) -> List[SamplerOutput]:
        """Given the accepted token ids, create a list of SamplerOutput.

        The output is padded with -1 tokens such that each sequence has
        the same number of outputs. One SamplerOutput is produced per step,
        stopping at the first step in which every sequence's token is -1.
        Rejection-sampling metrics, when available, are attached to the first
        SamplerOutput.
        """
        batch_size, num_steps = accepted_token_ids.shape

        # Organize input tensors by step instead of by sequence.
        target_logprobs_by_step = target_logprobs.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)

        # Get the logprobs/rank of the accepted tokens.
        (accepted_token_id_ranks_by_step,
         accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs(
             logprob_tensor=target_logprobs_by_step,
             sampled_token_ids=accepted_token_ids_by_step,
         )

        # Get the top-k logprobs (which may or may not include the logprob of
        # the accepted token).
        (topk_logprobs_by_step,
         topk_indices_by_step) = target_logprobs_by_step.topk(
             k=self.scorer_worker.model_config.max_logprobs,
             dim=-1,
         )

        # Get the sequence ids and num_logprobs (sampling parameter) in the
        # batch.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)

        # Serialize all tensors to CPU Python lists.
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
        accepted_token_id_ranks_by_step = (
            accepted_token_id_ranks_by_step.tolist())
        accepted_token_id_logprobs_by_step = (
            accepted_token_id_logprobs_by_step.tolist())
        topk_logprobs_by_step = topk_logprobs_by_step.tolist()
        topk_indices_by_step = topk_indices_by_step.tolist()

        # Construct the output on a per-step, per-sequence basis.
        sampler_output_list = []
        for step_index in range(num_steps):
            # A step where every sequence is padding carries no output.
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break

            step_output_token_ids = []
            for sequence_index in range(batch_size):
                # Each sequence may have a different num_logprobs; retrieve it.
                num_logprobs = num_logprobs_per_seq[sequence_index]

                step_output_token_ids.append(
                    create_sequence_group_output(
                        token_id=accepted_token_ids_by_step[step_index]
                        [sequence_index],
                        token_id_logprob_rank=accepted_token_id_ranks_by_step[
                            step_index][sequence_index],
                        token_id_logprob=accepted_token_id_logprobs_by_step[
                            step_index][sequence_index],
                        seq_id=seq_ids[sequence_index],
                        topk_token_ids=topk_indices_by_step[step_index]
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
                    ))

            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))

        maybe_rejsample_metrics = (
            self._metrics.maybe_collect_rejsample_metrics(k))
        if maybe_rejsample_metrics is not None:
            sampler_output_list[
                0].spec_decode_worker_metrics = maybe_rejsample_metrics

        return sampler_output_list

    @cached_property
    def _vocab_size(self) -> int:
        """Get the vocab size of the model and make sure it's consistent between
        draft and target workers.
        """
        vocab_sizes = [
            worker.vocab_size
            for worker in [self.proposer_worker, self.scorer_worker]
        ]
        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
        return vocab_sizes[0]

    @property
    def rank(self):
        """Rank of this worker, delegated to the scorer worker."""
        return self.scorer_worker.rank

    @property
    def device(self):
        """Device of this worker, delegated to the scorer worker."""
        return self.scorer_worker.device

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

        This function is only used to compose workers within a SpecDecodeWorker.
        We leave composing a SpecDecodeWorker within a SpecDecodeWorker
        undefined for now, although it could be implemented in the future.
        See https://arxiv.org/abs/2308.04623.
        """
        raise NotImplementedError


def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
                                  proposer_cache_block_size_bytes: int,
                                  total_num_gpu_blocks: int) -> int:
    """Compute how many GPU blocks each of the draft and target models gets.

    ``total_num_gpu_blocks`` is the number of GPU blocks that could be
    allocated to the target model alone. Usually the block size, in bytes, of
    each model is different, as it's a function of number of KV/layer, number
    of heads, and hidden dimension size.

    Since the target and draft models allocate the same number of blocks, we
    calculate the largest number of blocks such that, if allocated by both
    models, the total memory usage from KV cache is no larger than the memory
    of ``total_num_gpu_blocks`` target-model blocks, i.e. the floor of
    ``total * scorer_size / (proposer_size + scorer_size)``.

    Args:
        scorer_cache_block_size_bytes: Size of one target-model cache block.
        proposer_cache_block_size_bytes: Size of one draft-model cache block.
        total_num_gpu_blocks: Blocks the target model could use on its own.

    Returns:
        The number of blocks to allocate to each model.
    """
    # Exact integer floor division: float true division can round up past the
    # exact floor for large operands.
    new_num_gpu_blocks = (
        total_num_gpu_blocks * scorer_cache_block_size_bytes //
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return int(new_num_gpu_blocks)