Unverified Commit 62b8aebc authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951)

parent 050f285f
......@@ -6,8 +6,7 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import (maybe_mock_device_tensors,
sampler_output_to_torch)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker import Worker
......@@ -329,12 +328,15 @@ class DraftModelTop1Proposer(SpeculativeProposer):
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty tensors.
proposal_tokens = torch.zeros(0,
# In this case we return empty proposals.
proposal_tokens = torch.full(size=(
batch_size,
max_proposal_len,
),
fill_value=-1,
dtype=torch.long,
device=self._device)
proposal_probs = torch.zeros(0,
proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
......@@ -345,17 +347,6 @@ class DraftModelTop1Proposer(SpeculativeProposer):
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for step_output in sampler_output:
maybe_mock_device_tensors(
sampler_output=step_output,
batch_size=len(proposal_lens),
vocab_size=self._vocab_size,
device=self._device,
)
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
......
......@@ -111,6 +111,32 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
which significantly reduces overhead of rejection sampling.
NOTE(cade): This breaks abstraction boundaries pretty badly. The better
design is to have the "move to CPU and serialize" sampling decision be
done outside of the model/sampler; this way the "last-mile" worker
object which interfaces with the scheduler can serialize and incur the
performance hit as necessary. This allows us to run the worker several
iterations in a row without incurring the "move to CPU and serialize"
performance penalty.
Since this requires a large change to vLLM, we defer it to later and
temporarily accept this broken abstraction boundary.
NOTE(cade): This will require a special check if the proposer worker
does not have a sampler (e.g. ngram speculation).
"""
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
) = True
(self.proposer_worker.model_runner.model.sampler.
include_gpu_probs_tensor) = True
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
......@@ -286,15 +312,26 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices
proposal_probs = proposal_scores.probs[spec_indices, :-1]
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities of target model, excluding bonus token.
proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
accepted_token_ids = self.rejection_sampler(
proposal_probs,
bonus_token_ids,
proposals.proposal_probs,
proposals.proposal_token_ids,
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
)
# Append output tokens from non-speculative sequences to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment