Commit 01c30741 authored by lizhigong's avatar lizhigong
Browse files

add spec decode zero overhead

parent b01c8270
......@@ -6,8 +6,8 @@ from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
import zmq
from vllm import AsyncEngineArgs, SamplingParams
......
......@@ -43,8 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
from vllm.zero_overhead.v0.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.v0.utils import is_zero_overhead
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__)
......
......@@ -21,7 +21,7 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.zero_overhead.v0.utils import is_zero_overhead
from vllm.zero_overhead.utils import is_zero_overhead
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
......@@ -40,7 +40,7 @@ def get_sampler() -> torch.nn.Module:
from vllm.v1.sample.sampler import Sampler as V1Sampler
return V1Sampler()
if is_zero_overhead():
from vllm.zero_overhead.v0.sampler import ZeroOverheadSampler
from vllm.zero_overhead.sampler import ZeroOverheadSampler
return ZeroOverheadSampler()
return Sampler()
......
......@@ -54,6 +54,7 @@ from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.zero_overhead.utils import is_zero_overhead
logger = init_logger(__name__)
......@@ -206,8 +207,11 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if is_zero_overhead():
from vllm.zero_overhead.spec_decode.muti_step_worker import ZeroOverheadMultiStepWorker
proposer_worker = ZeroOverheadMultiStepWorker(**draft_worker_kwargs)
else:
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = \
draft_model_config.hf_config.n_predict
......@@ -254,17 +258,31 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
if is_zero_overhead():
from vllm.zero_overhead.spec_decode.spec_decode_worker import ZeroOverheadSpecDecodeWorker
return ZeroOverheadSpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
else:
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
def __init__(
self,
......
......@@ -60,7 +60,7 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.zero_overhead.v0.utils import is_zero_overhead
from vllm.zero_overhead.utils import is_zero_overhead
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -1638,7 +1638,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
ModelInputForGPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
if is_zero_overhead():
from vllm.zero_overhead.v0.model_runner import ZeroOverheadModelInputForGpuBuilder
from vllm.zero_overhead.model_runner import ZeroOverheadModelInputForGpuBuilder
_builder_cls = ZeroOverheadModelInputForGpuBuilder
def make_model_input_from_broadcasted_tensor_dict(
......
......@@ -33,14 +33,14 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.version import __version__ as VLLM_VERSION
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled
from vllm.utils import resolve_obj_by_qualname, weak_bind, Counter
from vllm.zero_overhead.v0.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
from vllm.zero_overhead.v0.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.v0.tokenizer import ZeroOverheadDetokenizer
from vllm.zero_overhead.sampler import SampleRecorder, get_last_sampler
from vllm.zero_overhead.sequence import ZeroOverheadSequence
from vllm.zero_overhead.stop_check import ZeroOverheadStopChecker
from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.profiler.prof import profile
from vllm.zero_overhead.v0.utils import is_zero_no_thread
from vllm.zero_overhead.utils import is_zero_no_thread
logger = init_logger(__name__)
......
......@@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.v0.sampler import get_last_sampler
from vllm.zero_overhead.v0.update_input import UpdateInputTokens
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.update_input import UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
......
from array import array
import numpy as np
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
SeqId = int
TargetSeqId = int
TokenId = int
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
proposal_lens_list = np.zeros(proposals.proposal_lens.shape, dtype=int).tolist() #zero_overhead todo fix
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist()
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if VLLM_INVALID_TOKEN_ID not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
#print('###execute_model_req', execute_model_req)
#print('###target_seq_group_metadata_list', target_seq_group_metadata_list)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
#print('###target_sampler_output', target_sampler_output)
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
return self._contract_batch(
execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
def _contract_non_speculative(
self, scores: SpeculativeScores,
seq_group_metadata_list: List[SequenceGroupMetadata],
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if not non_spec_indices:
return scores
if has_prompt_log:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta = seq_group_metadata_list
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(
range(len(non_spec_outputs.token_ids)))
nospec_sampled_token_idxs = async_tensor_h2d(nospec_sampled_token_idxs, torch.int32,
self._device,
True)
non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
self._device,
True)
scores.token_ids[non_spec_indices, :1] = \
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
scores.probs[non_spec_indices, :1, :] = \
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
scores.logprobs[non_spec_indices, :1, :] = \
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
if scores.hidden_states is not None:
assert non_spec_outputs.hidden_states is not None
scores.hidden_states[non_spec_indices, :1, :] = \
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
return scores
\ No newline at end of file
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
def init_device(self) -> None:
self.worker.init_device()
self._proposer = ZeroOverheadTop1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
print('###execute_model_req', execute_model_req)
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
print('###self.worker.execute_model', sample_len)
print('###expanded_request', expanded_request)
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
print('###model_output', model_output)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
def _filter_model_output_zero_overhead(self,
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
indices_of_seq_with_bonus_tokens = async_tensor_h2d(output_indices_to_retain, torch.int32,
self.device,
True)
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
\ No newline at end of file
This diff is collapsed.
import os
from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
class ZeroOverheadTop1Proposer(Top1Proposer):
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
self._device,
True)
proposal_len = [proposal_len for i in range(batch_size)]
proposal_len = async_tensor_h2d(proposal_len, torch.long,
self._device,
True)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = proposal_probs.new_zeros(
batch_size,
*proposal_probs.shape[1:],
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
......@@ -5,7 +5,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
......
......@@ -4,7 +4,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.v0.sequence import ZeroOverheadSequence
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
......
......@@ -9,12 +9,4 @@ def is_zero_overhead():
return zero_overhead
def is_zero_no_thread():
return zero_no_thread and zero_overhead
def UpdateInputTokens(input_tokens, last_sample, indices):
global _update_input_tokens_ptr
grid = [input_tokens.shape[0], 1, 1]
if _update_input_tokens_ptr is None:
_update_input_tokens_ptr = _update_input_tokens[grid](last_sample, input_tokens, indices, input_tokens.shape[0])
else:
_update_input_tokens_ptr[grid](last_sample, input_tokens, indices, input_tokens.shape[0])
\ No newline at end of file
return zero_no_thread and zero_overhead
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment