Unverified Commit e95cd879 authored by Cade Daniel's avatar Cade Daniel Committed by GitHub
Browse files

[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)

parent 69e1d2fb
......@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int) -> List[SamplerOutput]:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")
assert num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list)
......
......@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
blocks_to_copy: Dict[int, List[int]],
num_lookahead_slots: int = 0) -> SamplerOutput:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
......
......@@ -693,3 +693,16 @@ class SamplerOutput:
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
......@@ -6,10 +6,10 @@ import torch
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
sampler_output_to_torch,
from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
nvtx_range, sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import WorkerBase
SeqId = int
TargetSeqId = int
......@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
"""
def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
def __init__(self, scorer_worker: WorkerBase, device: str,
vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
......@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
return_python_output=False)
)
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list),
......@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This maps the scores of speculative tokens back to their original
sequences.
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors(
sampler_output=target_sampler_output,
batch_size=len(non_spec_indices) + num_scoring_tokens,
vocab_size=self._vocab_size,
device=self._device,
)
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
......
......@@ -6,7 +6,8 @@ import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.spec_decode.util import (maybe_mock_device_tensors,
sampler_output_to_torch)
from vllm.worker.worker import Worker
......@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
......@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
sampler_output = maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for step_output in sampler_output:
maybe_mock_device_tensors(
sampler_output=step_output,
batch_size=len(proposal_lens),
vocab_size=self._vocab_size,
device=self._device,
)
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
......
......@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
......@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
......@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
@classmethod
def from_workers(cls, proposer_worker: MultiStepWorker,
scorer_worker: WorkerBase) -> "SpecDecodeWorker":
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
# TODO(cade) disable strict mode for speedup.
rejection_sampler=RejectionSampler(strict_mode=True),
)
def __init__(
self,
proposer_worker: MultiStepWorker,
scorer_worker: Worker,
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
):
......@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.scorer_worker.init_device()
self.proposer_worker.init_device()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
......@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_spec_tokens: int,
num_lookahead_slots: int,
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
......@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list")
logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
......@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_spec_tokens,
k=num_lookahead_slots,
)
@nvtx_range("spec_decode_worker._run_no_spec")
......@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger.info("run proposer worker no spec")
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
return_python_output=False)
)
logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
......@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger.info("get spec proposals")
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
......@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals,
)
logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
......@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
logprobs={token_id: 0.0},
logprobs={token_id: Logprob(0.0)},
)
],
prompt_logprobs=None,
......
......@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return sampled_token_ids, sampled_token_probs
def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
vocab_size: int, device: str) -> None:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values = [
sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
]
assert all(v is None for v in values) or not any(v is None for v in values)
if not any(v is None for v in values):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output.sampled_token_probs = torch.nn.functional.softmax(
torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
dim=-1)
sampler_output.sampled_token_ids = torch.randint(low=10,
high=100,
size=(batch_size, ),
dtype=torch.long,
device=device)
@contextmanager
def nvtx_range(msg, *args, **kwargs):
"""
......
......@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
) -> List[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
......@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache)
return output
# CPU worker only supports single-step execution.
return [output]
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
......
"""A Neuron worker class."""
from typing import List, Optional, Tuple
from typing import List, Tuple
import torch
import torch.distributed
......@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]:
) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
return []
output = self.model_runner.execute_model(seq_group_metadata_list)
return output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return [output]
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
......
......@@ -210,7 +210,9 @@ class Worker(WorkerBase):
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
num_lookahead_slots: int = 0,
) -> List[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
......@@ -235,11 +237,14 @@ class Worker(WorkerBase):
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
return output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
......
......@@ -40,12 +40,13 @@ class WorkerBase(ABC):
raise NotImplementedError
@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
"""Executes one model step on the given sequences."""
def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise NotImplementedError
@abstractmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment