Commit 01c30741 authored by lizhigong's avatar lizhigong
Browse files

add spec decode zero overhead

parent b01c8270
...@@ -6,8 +6,8 @@ from contextlib import contextmanager ...@@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union from typing import Iterator, List, Optional, Union
import cloudpickle import cloudpickle
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead from vllm.zero_overhead.utils import is_zero_overhead
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
......
...@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, ...@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of) is_list_of)
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, ...@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob, CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.v0.utils import is_zero_overhead from vllm.zero_overhead.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module: ...@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module:
from vllm.v1.sample.sampler import Sampler as V1Sampler from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler() return V1Sampler()
if is_zero_overhead(): if is_zero_overhead():
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler from vllm.zero_overhead.sampler import ZeroOverheadSampler
return ZeroOverheadSampler() return ZeroOverheadSampler()
return Sampler() return Sampler()
......
...@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase ...@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -206,7 +207,10 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -206,7 +207,10 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device # Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle": if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True enable_lm_head_weight_load = True
if is_zero_overhead():
from vllm.zero_overhead.spec_decode.muti_step_worker import ZeroOverheadMultiStepWorker
proposer_worker = ZeroOverheadMultiStepWorker(**draft_worker_kwargs)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp": if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = \ num_spec_prefill_steps = \
...@@ -254,6 +258,20 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -254,6 +258,20 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the " "[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.") "target model is not running in eager mode.")
if is_zero_overhead():
from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
return ZeroOverheadSpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
else:
return SpecDecodeWorker( return SpecDecodeWorker(
proposer_worker, proposer_worker,
scorer_worker, scorer_worker,
......
...@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import ( ...@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.v0.utils import is_zero_overhead from vllm.zero_overhead.utils import is_zero_overhead
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead(): if is_zero_overhead():
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder _builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
......
...@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.v0.sampler import SampleRecorder, get_last_sampler from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence from vllm.zero_overhead.sequence import ZeroOverheadSequence
from vllm.zero_overhead.v0.stop_check import ZeroOverheadStopChecker from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.v0.tokenizer import ZeroOverheadDetokenizer from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.zero_overhead.v0.utils import is_zero_no_thread from vllm.zero_overhead.utils import is_zero_no_thread
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens from vllm.zero_overhead.update_input import UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder): class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
......
from array import array
import numpy as np
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
SeqId = int
TargetSeqId = int
TokenId = int
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
proposal_lens_list = np.zeros(proposals.proposal_lens.shape, dtype=int).tolist() #zero_overhead todo fix
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist()
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if VLLM_INVALID_TOKEN_ID not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
#print('###execute_model_req', execute_model_req)
#print('###target_seq_group_metadata_list', target_seq_group_metadata_list)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
#print('###target_sampler_output', target_sampler_output)
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
return self._contract_batch(
execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
def _contract_non_speculative(
self, scores: SpeculativeScores,
seq_group_metadata_list: List[SequenceGroupMetadata],
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if not non_spec_indices:
return scores
if has_prompt_log:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta = seq_group_metadata_list
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(
range(len(non_spec_outputs.token_ids)))
nospec_sampled_token_idxs = async_tensor_h2d(nospec_sampled_token_idxs, torch.int32,
self._device,
True)
non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
self._device,
True)
scores.token_ids[non_spec_indices, :1] = \
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
scores.probs[non_spec_indices, :1, :] = \
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
scores.logprobs[non_spec_indices, :1, :] = \
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
if scores.hidden_states is not None:
assert non_spec_outputs.hidden_states is not None
scores.hidden_states[non_spec_indices, :1, :] = \
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
return scores
\ No newline at end of file
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
def init_device(self) -> None:
self.worker.init_device()
self._proposer = ZeroOverheadTop1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
print('###execute_model_req', execute_model_req)
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
print('###self.worker.execute_model', sample_len)
print('###expanded_request', expanded_request)
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
print('###model_output', model_output)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
def _filter_model_output_zero_overhead(self,
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
indices_of_seq_with_bonus_tokens = async_tensor_h2d(output_indices_to_retain, torch.int32,
self.device,
True)
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
\ No newline at end of file
import os
import copy
from collections import defaultdict
from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import (broadcast_tensor_dict,
get_tp_group,
tensor_model_parallel_gather)
from vllm.distributed.parallel_state import model_parallel_is_initialized
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SequenceGroupMetadata,
get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.target_model_runner import TargetModelRunner
from vllm.spec_decode.util import (Timer, create_logprobs_output,
create_sequence_group_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
logger = init_logger(__name__)
class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
"""
# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self.scorer_worker.init_device()
self.proposer_worker.init_device()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self.scorer_worker.load_model()
self.proposer_worker.load_model()
if self._enable_lm_head_weight_load:
# NOTE(Shangming): gather lm_head weight when tp enabled
target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather(
self.scorer_worker.model_runner.model_runner.model.lm_head.\
weight.data,
dim=0,
)
self.proposer_worker.maybe_load_lm_head_weight(
target_lm_head_weight)
self._metrics.init_tensors(self.rank, device_type=self.device)
if model_parallel_is_initialized():
self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
device_type=self.device)
else:
self.spec_decode_sampler.init_tensors(self.rank,
device_type=self.device)
scorer_cls: Type[SpeculativeScorer]
if self.disable_mqa_scorer:
scorer_cls = ZeroOverheadBatchExpansionTop1Scorer
logger.info("[Speculative Decoding] Use batch "
"expansion for scoring proposals.")
else:
scorer_cls = MQAScorer
logger.info(
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
if not self.tree_decoding:
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
else:
self.scorer = BatchExpansionTreeStyleScorer(
scorer_worker=self.scorer_worker,
device=self.device,
vocab_size=self._vocab_size)
self._configure_model_sampler_for_spec_decode()
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_scores: SpeculativeScores,
proposals: SpeculativeProposals,
max_proposal_len: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[List[int]], List[int]]:
"""Determine which speculative tokens are accepted using the
probabilities of each token according to the proposer and scorer models.
Returns a tuple of Tensors, one for the accepted token ids and one for
the logprobs according to the scoring model.
"""
proposal_lens_list = proposals.proposal_lens
# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
(_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
original_indices = spec_indices + non_spec_indices
# Get probabilities of target model, including bonus tokens.
if non_spec_indices:
proposal_verifier_probs = proposal_scores.probs[spec_indices]
else:
proposal_verifier_probs = proposal_scores.probs
if self.tree_decoding:
retrieve_indices = proposals.retrieve_indices
proposal_verifier_probs = proposal_verifier_probs[:, retrieve_indices]
# Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
# Get bonus tokens from target model.
bonus_token_ids = proposal_scores.token_ids[:, -1:]
if non_spec_indices:
bonus_token_ids = bonus_token_ids[spec_indices, :]
# Get probabilities according to proposal method.
proposal_probs = proposals.proposal_probs if proposals.proposal_probs is not None else None
if proposal_probs is not None and non_spec_indices:
proposal_probs = proposal_probs[spec_indices]
# Get proposed tokens.
proposal_token_ids = proposals.proposal_token_ids
if non_spec_indices:
proposal_token_ids = proposal_token_ids[spec_indices]
# Get tree buffers.
cart_candidates = proposals.cart_candidates if proposals.cart_candidates is not None else None
if cart_candidates is not None and non_spec_indices:
cart_candidates = cart_candidates[spec_indices]
# Sampler arguments
sampler_extra_kwargs: Dict[str, Any] = {}
if self.generators and isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
sampler_extra_kwargs["seeded_seqs"] = {
idx: self.generators[sgm.request_id]
for idx, sgm in enumerate(seq_group_metadata_list)
if sgm.sampling_params.seed is not None
}
if isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler):
sampler_extra_kwargs["cart_candidates"] = cart_candidates
sampler_extra_kwargs["best_candidates"] = []
sampler_extra_kwargs["accept_lengths"] = []
first_step_flags = []
for i, sgm in enumerate(seq_group_metadata_list):
seq = next(iter(sgm.seq_data.values()))
first_step_flags.append(True if seq.get_first_step_flag() else False)
sampler_extra_kwargs["first_step_flags"] = first_step_flags
accepted_token_ids = self.spec_decode_sampler(
target_with_bonus_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
**sampler_extra_kwargs,
)
# Append output tokens from non-speculative sequences to
# the accepted token ids tensor.
if not self.tree_decoding:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
1).clone()
else:
non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len).clone()
non_spec_token_ids[:, 1:] = -1
accepted_token_ids = torch.cat(
[accepted_token_ids, non_spec_token_ids])
logprobs = proposal_scores.logprobs
# Rearrange so that results are in the order of the original seq group
# metadata.
original_indices = async_tensor_h2d(original_indices, torch.int32,
self.device,
True)
accepted_token_ids[original_indices] = accepted_token_ids.clone()
# B x K+1 x D
hidden_states = proposal_scores.hidden_states
select_indices = None
accept_lengths = None
select_indices_list = []
if cart_candidates is None:
if hidden_states is not None:
# Only get terminal hidden states for next step
terminal_metadata = [
sg for sg in seq_group_metadata_list if sg.do_sample
]
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1]
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
# Drop non-terminal prefill chunks hidden states.
hidden_states = hidden_states[accepted_index !=
VLLM_INVALID_TOKEN_ID]
accepted_index = accepted_index[accepted_index !=
VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0] == len(
terminal_metadata)
index = accepted_index[:, None, None].expand(-1, 1,
hs_size) # b x 1 x d
second_last_token_hidden_states = hidden_states[:, -2] # b x d
hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d
# Store hidden states from target model for subsequent decode step
self.previous_hidden_states = HiddenStates(
hidden_states, terminal_metadata,
second_last_token_hidden_states)
else:
retrieve_indices = proposals.retrieve_indices
batch_size = len(seq_group_metadata_list)
best_candidates = sampler_extra_kwargs["best_candidates"]
accept_lengths = sampler_extra_kwargs["accept_lengths"]
# Contract hidden states based on accepted tokens
hs_size = hidden_states.shape[-1]
hidden_states = hidden_states.view(batch_size, -1, hs_size)
# Store logits from target model for subsequent proposal
logits = proposal_scores.logits
logits = logits.view(batch_size, -1, logits.shape[-1])
logits = logits[:, retrieve_indices] # [batch_size, retrieve_size, max_depth, vocab_size]
previous_logits_list = []
previous_hidden_state_list = []
retrieve_indices = retrieve_indices.cpu()
for i in range(batch_size):
logit = logits[i, best_candidates[i], accept_lengths[i]].unsqueeze(0)
previous_logits_list.append(logit)
select_indices = retrieve_indices[best_candidates[i], :accept_lengths[i]+1]
hidden_state = hidden_states[i, select_indices[-1]].unsqueeze(0)
select_indices_list.append(select_indices)
previous_hidden_state_list.append(hidden_state)
logits = torch.cat(previous_logits_list, dim=0)
self.previous_logits = Logits(logits, seq_group_metadata_list)
hidden_states = torch.cat(previous_hidden_state_list, dim=0) # [batch_size, 1, vocab_size]
self.previous_hidden_states = HiddenStates(hidden_states,
seq_group_metadata_list,)
return accepted_token_ids, logprobs, select_indices_list, accept_lengths
def _create_output_sampler_list(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1]
target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size]
prompt_logprobs: Optional[
torch.Tensor], # shape: [nprompt_tokens, vocab_size]
k: int,
stage_times: Tuple[float, float, float],
) -> List[SamplerOutput]:
"""Given the accepted token ids, create a list of SamplerOutput.
The output is padded with -1 tokens such that each sequence has
the same number of outputs.
"""
batch_size, num_steps = accepted_token_ids.shape
accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1)
if self._disable_logprobs:
# We are skipping the logprobs. Hence don't serialize the
# logprobs related tensors from the GPU. Instead create
# empty/dummy lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_dummy_logprob_lists(
batch_size, num_steps,
self.scorer_worker.model_config.max_logprobs)
else:
# Organize input tensors by step instead of by sequence.
target_logprobs_by_step = target_logprobs.transpose(0, 1)
# Serialize all tensors into Python lists.
(accepted_token_id_ranks_by_step,
accepted_token_id_logprobs_by_step,
topk_logprobs_by_step, topk_indices_by_step) =\
self._create_logprob_lists_from_tensors(
target_logprobs_by_step, accepted_token_ids_by_step,
self.scorer_worker.model_config.max_logprobs)
# Get the sequence ids and num_logprobs (sampling parameter) in the
# batch.
seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids(
seq_group_metadata_list)
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize tensor to CPU Python list.
#print('###accepted_token_ids_by_step', accepted_token_ids_by_step)
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
# i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while
# terminal chunks will only have one generated token at time 0.
sampler_output_list: List[SamplerOutput] = []
# Prefills are not multi-step (return at most 1 token), in order to
# avoid padding or repetition to fit decodes, we separate them.
for i, sg in enumerate(seq_group_metadata_list):
if not sg.is_prompt:
# Requests are ordered as prefills|decodes=>no more prefills.
break
num_logprobs = num_logprobs_per_seq[i]
seq_kwargs = dict(token_id=-1,
token_id_logprob_rank=0,
token_id_logprob=-float('inf'),
topk_token_ids=[-1] * num_logprobs,
topk_logprobs=[-float('inf')] * num_logprobs,
seq_id=seq_ids[i])
# Terminal chunk, has token.
if sg.do_sample:
seq_kwargs.update(
dict(
token_id=accepted_token_ids[i][0].item(),
token_id_logprob_rank=accepted_token_id_ranks_by_step[
0][i],
token_id_logprob=accepted_token_id_logprobs_by_step[0]
[i],
topk_token_ids=topk_indices_by_step[0][i]
[:num_logprobs],
# output only so step is 0
topk_logprobs=topk_logprobs_by_step[0][i]
[:num_logprobs],
))
needs_plogs = (sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
plogs = None
if prompt_logprobs is not None:
# Even non-terminal prompt chunks can have logprobs here.
plogs = prompt_logprobs[i]
elif needs_plogs:
# Prompt logprobs are requested but `_disable_logprobs` is set.
seq_data = next(iter(sg.seq_data.values()))
# Get only the tokens in this chunk!
prompt_token_ids = seq_data.get_prompt_token_ids()
prompt_token_ids = prompt_token_ids[
seq_data.
_num_computed_tokens:seq_data._num_computed_tokens +
sg.token_chunk_size]
is_first_chunk = seq_data._num_computed_tokens == 0
# There's no prob generated for the first token in a sequence.
if is_first_chunk:
prompt_token_ids = prompt_token_ids[1:]
plogs = [
create_logprobs_output(
token_id=p_token_id,
token_id_logprob_rank=-1,
token_id_logprob=0.0,
topk_token_ids=[],
topk_logprobs=[],
) for p_token_id in prompt_token_ids
]
seq_kwargs.update(dict(prompt_logprobs=plogs))
sampler_output_list.append(
SamplerOutput(
outputs=[create_sequence_group_output(
**seq_kwargs)])) # type: ignore
# Decodes, create one SamplerOutput per-step (at most K+1).
for step_index in range(num_steps):
# if all(token_id == -1 for sg, token_id in zip(
# seq_group_metadata_list,
# accepted_token_ids_by_step[step_index])
# if not sg.is_prompt):
# break
step_output_token_ids: List[CompletionSequenceGroupOutput] = []
for sequence_index in range(batch_size):
seq_meta = seq_group_metadata_list[sequence_index]
# Prompts already processed above.
if seq_meta.is_prompt:
continue
# Each sequence may have a different num_logprobs; retrieve it.
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
step_index][sequence_index],
seq_id=seq_ids[sequence_index],
topk_token_ids=topk_indices_by_step[step_index]
[sequence_index][:num_logprobs],
topk_logprobs=topk_logprobs_by_step[step_index]
[sequence_index][:num_logprobs],
))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
# Populate the data structures needed to keep track of sequences with
# bonus tokens.
self._track_sequences_with_bonus_tokens(seq_ids,
request_ids_seq_ids_mapping,
accepted_token_ids_by_step)
maybe_rejsample_metrics = (
self._metrics.maybe_collect_rejsample_metrics(k))
if maybe_rejsample_metrics is not None and sampler_output_list:
sampler_output_list[
0].spec_decode_worker_metrics = maybe_rejsample_metrics
# Log time spent in each stage periodically.
# This is periodic because the rejection sampler emits metrics
# periodically.
self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
return sampler_output_list
def _track_sequences_with_bonus_tokens(
self, seq_ids: List[int],
request_ids_seq_ids_mapping: Dict[str, Set[int]],
accepted_token_ids_by_step: List[List[int]]):
"""
Updates the internal data structures which keep track of sequences
which have been assigned bonus tokens in their last forward pass.
"""
for seq_index, seq_id in enumerate(seq_ids):
# last_token_id = accepted_token_ids_by_step[-1][seq_index]
# if last_token_id == -1:
# self._seq_with_bonus_token_in_last_step.discard(seq_id)
# else:
self._seq_with_bonus_token_in_last_step.add(seq_id)
for request_id, sequences in request_ids_seq_ids_mapping.items():
self._request_id_seq_id_mapping[request_id].update(sequences)
\ No newline at end of file
import os
from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
class ZeroOverheadTop1Proposer(Top1Proposer):
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
self._device,
True)
proposal_len = [proposal_len for i in range(batch_size)]
proposal_len = async_tensor_h2d(proposal_len, torch.long,
self._device,
True)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = proposal_probs.new_zeros(
batch_size,
*proposal_probs.shape[1:],
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
...@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker ...@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus from vllm.sequence import SequenceStatus
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker): class ZeroOverheadStopChecker(StopChecker):
......
...@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams ...@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer): class ZeroOverheadDetokenizer(Detokenizer):
......
...@@ -10,11 +10,3 @@ def is_zero_overhead(): ...@@ -10,11 +10,3 @@ def is_zero_overhead():
def is_zero_no_thread(): def is_zero_no_thread():
return zero_no_thread and zero_overhead return zero_no_thread and zero_overhead
def UpdateInputTokens(input_tokens, last_sample, indices):
global _update_input_tokens_ptr
grid = [input_tokens.shape[0], 1, 1]
if _update_input_tokens_ptr is None:
_update_input_tokens_ptr = _update_input_tokens[grid](last_sample, input_tokens, indices, input_tokens.shape[0])
else:
_update_input_tokens_ptr[grid](last_sample, input_tokens, indices, input_tokens.shape[0])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment