Unverified Commit 15702038 authored by Lily Liu's avatar Lily Liu Committed by GitHub
Browse files

[Spec Decode] (1/2) Remove batch expansion (#8839)

parent 22f5851b
import functools
from typing import Callable, List
from typing import Callable, List, Optional
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
......@@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def process_outputs(self,
sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput],
is_async: bool = False) -> None:
is_async: bool = False) -> Optional[int]:
"""Append new tokens in the outputs to sequences in the sequence group.
This only supports sequence groups of size 1. It supports greater than
......@@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
tokens from the previous step. If this is true, then
no tokens need to be appended since it is already done
externally (before the next schedule() call)
Returns:
The number of tokens appended to the sequence. This is optional
because only speculative decode uses this return value.
"""
# Sequences can be in RUNNING or FINISHED_ABORTED state
# once scheduled, as a sequence is moved to FINISHED_ABORTED
......@@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# was already appended, so we only need to do the rest of the
# postprocessor: Detokenization + stopping logic
self._process_decode_and_stop(seq, sequence_group.sampling_params)
return None
else:
# Standard multi-step case
......@@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
]
assert valid_samples
self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
return self._process_seq_outputs(seq, valid_samples,
sequence_group.sampling_params)
def _process_decode_and_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
......@@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def _process_seq_outputs(self, seq: Sequence,
valid_samples: List[SequenceOutput],
sampling_params: SamplingParams) -> None:
sampling_params: SamplingParams) -> int:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
......@@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
len(output_token_ids))
if remaining_tokens < 0:
valid_samples = valid_samples[:remaining_tokens]
output_token_ids = output_token_ids[:remaining_tokens]
# Truncate any tokens after EOS. This is required as spec decode
......@@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for i in range(len(output_token_ids)):
if output_token_ids[i] == eos_token_id:
output_token_ids = output_token_ids[:i + 1]
valid_samples = valid_samples[:i + 1]
break
# Incrementally append tokens to the sequence, as if we had only one new
......@@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
token_id=output_token_id,
logprobs=output_logprob,
)
seq.data.update_num_computed_tokens(1)
self._process_decode_and_stop(seq, sampling_params)
if seq.is_finished():
break
return len(output_token_ids)
......@@ -912,7 +912,7 @@ def get_logprobs(
sampling_metadata: SamplingMetadata,
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
"""Return sample logprobs and prompt logprobs.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and
......
......@@ -146,7 +146,7 @@ class SamplingMetadata:
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
query_lens: List[int],
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
......@@ -194,7 +194,7 @@ class SamplingMetadata:
def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: Optional[List[int]],
query_lens: List[int],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
......@@ -284,7 +284,8 @@ def _prepare_seq_groups(
else:
# Decode
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
query_len = query_lens[i] if query_lens is not None else 1
sample_len = len(seq_ids) * query_len if do_sample else 0
if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)
......@@ -440,14 +441,14 @@ class SamplingTensors:
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
min_ps += [min_p] * len(seq_ids)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
assert sample_lens >= len(seq_ids)
temperatures += [temperature] * sample_lens
top_ps += [top_p] * sample_lens
top_ks += [top_k] * sample_lens
min_ps += [min_p] * sample_lens
presence_penalties += [p] * sample_lens
frequency_penalties += [f] * sample_lens
repetition_penalties += [r] * sample_lens
if do_penalties:
for seq_group in sampling_metadata.seq_groups:
......
......@@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase
SeqId = int
TargetSeqId = int
......@@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
"""
def __init__(
    self,
    scorer_worker: WorkerBase,
    device: str,
    vocab_size: int,
):
    """Remember the collaborators this scorer is configured with.

    Args:
        scorer_worker: Worker object kept on ``self._scorer_worker``.
        device: Device string kept on ``self._device``.
        vocab_size: Vocabulary size kept on ``self._vocab_size``.
    """
    # Plain attribute storage only; nothing is validated or initialised here.
    self._vocab_size = vocab_size
    self._device = device
    self._scorer_worker = scorer_worker
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
......
......@@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
......
......@@ -5,6 +5,7 @@ from typing import Optional, Set
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.worker.worker_base import WorkerBase
@dataclass
......@@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
class SpeculativeScorer(ABC):
def __init__(
    self,
    scorer_worker: WorkerBase,
    device: str,
    vocab_size: int,
):
    """Store the state shared by scorer implementations.

    Args:
        scorer_worker: Worker object kept on ``self._scorer_worker``.
        device: Device string kept on ``self._device``.
        vocab_size: Vocabulary size kept on ``self._vocab_size``.
    """
    # Plain attribute storage only; nothing is validated or initialised here.
    self._vocab_size = vocab_size
    self._device = device
    self._scorer_worker = scorer_worker
@abstractmethod
def score_proposals(
self,
......
from vllm.sequence import (ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
SeqId = int
TargetSeqId = int
class MQAScorer(SpeculativeScorer):
    """Scorer that evaluates every proposal with a single worker call.

    For each sequence group in the request, a new single-sequence group is
    built whose output tokens are the existing output followed by the
    proposed tokens. The scorer worker is executed once on those groups and
    its sampler output is reshaped to one row of ``k + 1`` entries per
    sequence, where ``k`` is the second dimension of the proposal tensor.
    """

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score the proposed tokens of every sequence in the request.

        Args:
            execute_model_req: Request whose ``seq_group_metadata_list``
                holds the sequence groups to score. Each group must contain
                exactly one sequence with at least one output token.
            proposals: Proposals whose ``proposal_token_ids`` is a 2-D
                tensor; row ``i`` is paired with the ``i``-th sequence group.

        Returns:
            A ``SpeculativeScores`` whose ``token_ids`` is reshaped to
            ``(bs, k + 1)`` and whose ``probs`` and ``logprobs`` are reshaped
            to ``(bs, k + 1, vocab_size)``. ``hidden_states`` is reshaped to
            ``(bs, k + 1, -1)`` when the worker returns it, otherwise
            ``None``.
        """
        source_groups = execute_model_req.seq_group_metadata_list

        # New sequence ids start right after the largest id in the batch.
        first_target_id = max(get_all_seq_ids(source_groups)) + 1
        proposals_per_seq = proposals.proposal_token_ids.tolist()

        scoring_groups = []
        for idx, group in enumerate(source_groups):
            seq_table = group.seq_data
            assert len(seq_table) == 1
            src_id = next(iter(seq_table))
            src_data: SequenceData = seq_table[src_id]

            prompt_ids = src_data.get_prompt_token_ids()
            generated_ids = src_data.get_output_token_ids()
            drafted_ids = proposals_per_seq[idx]
            dst_id = first_target_id + idx

            # Existing output followed by the proposed tokens.
            scored_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_ids,
                output_token_ids=[*generated_ids, *drafted_ids],
            )
            # Everything except the last pre-existing token is reported as
            # already computed (presumably so that token plus the proposals
            # form the query -- confirm against the model runner).
            scored_data.update_num_computed_tokens(
                len(prompt_ids) + len(generated_ids) - 1)

            # The MQA scorer is only used in the decoding stage, so the
            # sequence must already have produced at least one token.
            assert len(generated_ids) >= 1

            scoring_groups.append(
                SequenceGroupMetadata(
                    request_id=group.request_id,
                    is_prompt=group.is_prompt,
                    seq_data={dst_id: scored_data},
                    sampling_params=group.sampling_params,
                    # Reuse the source sequence's block table under the
                    # new sequence id.
                    block_tables={dst_id: group.block_tables[src_id]},
                    lora_request=None,
                    token_chunk_size=1,
                ))

        # Only the first element of the worker's return value is used.
        sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=scoring_groups))[0]

        batch_size, num_proposed = proposals.proposal_token_ids.shape
        per_seq = num_proposed + 1

        token_ids = sampler_output.sampled_token_ids.reshape(
            batch_size, per_seq)
        probs = sampler_output.sampled_token_probs.reshape(
            batch_size, per_seq, self._vocab_size)
        logprobs = sampler_output.logprobs.reshape(
            batch_size, per_seq, self._vocab_size)

        raw_hidden = sampler_output.hidden_states
        if raw_hidden is None:
            hidden_states = None
        else:
            hidden_states = raw_hidden.reshape(batch_size, per_seq, -1)

        return SpeculativeScores(
            probs=probs,
            token_ids=token_ids,
            logprobs=logprobs,
            hidden_states=hidden_states,
        )
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
......@@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
......@@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
draft_token_acceptance_method=speculative_config.
......@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
cls,
scorer_worker: Worker,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
......@@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info("Configuring SpecDecodeWorker with sampler=%s",
type(spec_decode_sampler))
logger.info(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name(
) != "flash-attn":
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if ngram_prompt_lookup_max > 0:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer.")
if "model_config" in draft_worker_kwargs and \
draft_worker_kwargs["model_config"].max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info(
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
......@@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker: ProposerWorkerBase,
scorer_worker: WorkerBase,
spec_decode_sampler: SpecDecodeBaseSampler,
disable_mqa_scorer: bool = False,
disable_logprobs: bool = False,
disable_log_stats: bool = False,
metrics_collector: Optional[AsyncMetricsCollector] = None,
......@@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
types of sampler namely RejectionSampler and
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
instance of RejectionSampler or TypicalAcceptanceSampler.
disable_mqa_scorer: If set to True, disable the MQA scorer and use
the BatchExpansionTop1Scorer instead.
disable_logprobs: If set to True, token log probabilities will
not be output in both the draft worker and the target worker.
If set to False, log probabilities will be output by both.
......@@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
# Lazy initialization.
self.scorer: SpeculativeScorer
self.disable_mqa_scorer = disable_mqa_scorer
# Hidden states from target model to pass to proposer
# in the subsequent step.
......@@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._metrics.init_gpu_tensors(self.rank)
self.spec_decode_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
scorer_cls = BatchExpansionTop1Scorer
logger.info("[Speculative Decoding] Use batch "
"expansion for scoring proposals.")
else:
scorer_cls = MQAScorer
logger.info(
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
......
......@@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if inter_data.is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
seq_len = min(seq_len, context_len + token_chunk_size)
elif self.runner.scheduler_config.is_multi_step or \
self.runner.model_config.is_encoder_decoder_model:
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
else:
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()
if context_len != 0 or seq_len < len(tokens):
tokens = tokens[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = seq_data.get_last_token_id()
tokens = seq_data.get_token_ids()[context_len:seq_len]
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
if isinstance(tokens, list):
inter_data.input_tokens[seq_idx].extend(tokens)
else:
inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.query_lens[seq_idx] = seq_len - context_len
if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment