Commit 1bd3ae33 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip silu_mul_fp8_quant_deep_gemm_cuda and remove zero_overhead

parent 9bf1b213
...@@ -276,6 +276,9 @@ class ModelConfig: ...@@ -276,6 +276,9 @@ class ModelConfig:
override_pooler_config: Optional[Union[dict, PoolerConfig]] = None override_pooler_config: Optional[Union[dict, PoolerConfig]] = None
"""[DEPRECATED] Use `pooler_config` instead. This field will be removed in """[DEPRECATED] Use `pooler_config` instead. This field will be removed in
v0.12.0 or v1.0.0, whichever is sooner.""" v0.12.0 or v1.0.0, whichever is sooner."""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
# Multimodal config and init vars # Multimodal config and init vars
multimodal_config: Optional[MultiModalConfig] = None multimodal_config: Optional[MultiModalConfig] = None
...@@ -320,6 +323,7 @@ class ModelConfig: ...@@ -320,6 +323,7 @@ class ModelConfig:
factors.append(self.rope_scaling) factors.append(self.rope_scaling)
factors.append(self.rope_theta) factors.append(self.rope_theta)
factors.append(self.video_pruning_rate) factors.append(self.video_pruning_rate)
factors.append(self.enable_chunked_prefill)
# hf_config can control how the model looks! # hf_config can control how the model looks!
try: try:
......
...@@ -56,7 +56,6 @@ from vllm.v1.engine.llm_engine import LLMEngine ...@@ -56,7 +56,6 @@ from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
import vllm.envs as envs import vllm.envs as envs
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -300,12 +299,8 @@ class LLM: ...@@ -300,12 +299,8 @@ class LLM:
log_non_default_args(engine_args) log_non_default_args(engine_args)
# Create the Engine (autoselects V0 vs V1) # Create the Engine (autoselects V0 vs V1)
if envs.VLLM_ZERO_OVERHEAD: self.llm_engine = LLMEngine.from_engine_args(
self.llm_engine = ZeroOverheadEngine.from_engine_args( engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
else:
self.llm_engine = LLMEngine.from_engine_args(
engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
self.engine_class = type(self.llm_engine) self.engine_class = type(self.llm_engine)
self.request_counter = Counter() self.request_counter = Counter()
......
...@@ -1840,8 +1840,7 @@ class FusedMoE(CustomOp): ...@@ -1840,8 +1840,7 @@ class FusedMoE(CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor)
e_score_correction_bias=e_score_correction_bias)
if indices_type is not None: if indices_type is not None:
topk_ids = topk_ids.to(dtype=indices_type) topk_ids = topk_ids.to(dtype=indices_type)
elif e_score_correction_bias is not None: elif e_score_correction_bias is not None:
......
...@@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, TypeVar, Union ...@@ -16,7 +16,6 @@ from typing import Any, Callable, Optional, TypeVar, Union
import msgspec import msgspec
from vllm import envs from vllm import envs
from vllm.zero_overhead.v1.core import engine_core_step
import zmq import zmq
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
...@@ -277,8 +276,6 @@ class EngineCore: ...@@ -277,8 +276,6 @@ class EngineCore:
Returns tuple of outputs and a flag indicating whether the model Returns tuple of outputs and a flag indicating whether the model
was executed. was executed.
""" """
if envs.VLLM_ZERO_OVERHEAD:
return engine_core_step(self)
# Check for any requests remaining in the scheduler - unfinished, # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch. # or finished and not yet removed from the batch.
......
...@@ -1458,25 +1458,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1458,25 +1458,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# [0, 1, 2, 5, 6, 9] # [0, 1, 2, 5, 6, 9]
target_logits_indices += arange target_logits_indices += arange
if envs.VLLM_ZERO_OVERHEAD: # if envs.VLLM_ZERO_OVERHEAD:
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to( # cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).pin_memory().to(
self.device, non_blocking=True) # self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device, # logits_indices = torch.from_numpy(logits_indices).pin_memory().to(self.device,
non_blocking=True) # non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to( # target_logits_indices = torch.from_numpy(target_logits_indices).pin_memory().to(
self.device, non_blocking=True) # self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to( # bonus_logits_indices = torch.from_numpy(bonus_logits_indices).pin_memory().to(
self.device, non_blocking=True) # self.device, non_blocking=True)
else: # else:
# TODO: Optimize the CPU -> GPU copy. # TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True) self.device, non_blocking=True)
logits_indices = torch.from_numpy(logits_indices).to(self.device, logits_indices = torch.from_numpy(logits_indices).to(self.device,
non_blocking=True) non_blocking=True)
target_logits_indices = torch.from_numpy(target_logits_indices).to( target_logits_indices = torch.from_numpy(target_logits_indices).to(
self.device, non_blocking=True) self.device, non_blocking=True)
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
self.device, non_blocking=True) self.device, non_blocking=True)
# Compute the draft token ids. # Compute the draft token ids.
......
...@@ -34,8 +34,6 @@ from vllm.v1.utils import report_usage_stats ...@@ -34,8 +34,6 @@ from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
from vllm.zero_overhead.utils import zero_overhead_stream
from vllm.zero_overhead.v1.gpu_model_runner import V1ZeroModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -200,13 +198,8 @@ class Worker(WorkerBase): ...@@ -200,13 +198,8 @@ class Worker(WorkerBase):
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Construct the model runner # Construct the model runner
if envs.VLLM_ZERO_OVERHEAD: self.model_runner: GPUModelRunner = GPUModelRunner(
logger.info('use zero overhead model_runner') self.vllm_config, self.device)
self.model_runner: GPUModelRunner = V1ZeroModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
if self.rank == 0: if self.rank == 0:
# If usage stat is enabled, collect relevant info. # If usage stat is enabled, collect relevant info.
...@@ -451,14 +444,8 @@ class Worker(WorkerBase): ...@@ -451,14 +444,8 @@ class Worker(WorkerBase):
all_gather_group=get_tp_group(), all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors)) all_gather_tensors=all_gather_tensors))
if envs.VLLM_ZERO_OVERHEAD: output = self.model_runner.execute_model(scheduler_output,
use_stream = zero_overhead_stream(self.device) intermediate_tensors)
with torch.cuda.stream(use_stream):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
else:
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
return output return output
......
...@@ -52,8 +52,7 @@ class WorkerBase: ...@@ -52,8 +52,7 @@ class WorkerBase:
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers. communicate request metadata to other workers.
""" """
# TODO
model_input: Optional[ModelRunnerInputBase] = None
tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1') tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
def __init__( def __init__(
......
This diff is collapsed.
import torch
import itertools
from typing import List, Optional, Set
from vllm.lora.layers import LoRAMapping
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
import triton
import triton.language as tl
@triton.jit
def _update_input_tokens(
accepted_req_ids,
accepted_req_ids_len,
accepted_token_ids,
accepted_token_len,
chidren_req_ids,
chidren_req_ids_len,
input_tokens,
input_tokens_len,
input_positions,
seq_lens,
seq_lens_meta,
seq_lens_tensor,
slot_mapping,
seq_start_loc,
context_lens_tensor,
):
chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
for seq_id_idx in range(chidren_req_ids_len / 2):
seq_id = chidren_req_ids_[2 * seq_id_idx]
for i in range(accepted_req_ids_len):
if seq_id == accepted_req_ids_[i]:
accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
accepted_token_counter = 0
for j in range(accepted_token_len):
if accepted_token_ids_[j] == -1:
break
accepted_token_counter += 1
if accepted_token_counter == accepted_token_len:
tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
else:
tl.store(input_tokens + seq_id_idx * 2, 0)
tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
input_pos[0] = 0
input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
input_pos[0] = -1
tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
input_pos[0] = 1
input_pos[1] = input_pos[1] + 1
tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
seq_start_loc_ = tl.zero_like(seq_start_loc)
for i in range(input_tokens_len):
seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def __init__(self, runner, finished_requests_ids = None):
super().__init__(runner, finished_requests_ids)
self.req_ids = []
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
self.req_ids.clear()
return super().prepare(finished_requests_ids)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids)
seq_ids = list(seq_ids)
for seq_idx in range(n_seqs):
self.req_ids.append(seq_ids[seq_idx])
return super().add_seq_group(seq_group_metadata)
def build(self) -> ModelInputForGPU:
model_input = super().build()
last_sampler = get_last_sampler()
spec_step = get_spec_step()
last_step = get_spec_last_step()
if last_sampler is not None:
if spec_step == SpecStepKind.KIND_DEFAULT:
update_indices = []
select_indices = []
query_idx = 0
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(query_idx)
break
query_idx += model_input.query_lens[i]
if len(select_indices) > 0 and last_sampler.sampled_token_ids_tensor is not None:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
if spec_step == SpecStepKind.OTHER_PROPOSAL:
if last_step == SpecStepKind.OTHER_PROPOSAL: # copy last sampled token ids to input tokens directly.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
if last_step == SpecStepKind.FIRST_PROPOSAL: # TODO: ajust input tokens number to 1 per request.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
if spec_step == SpecStepKind.SCORE_DECODE:
proposal_token_ids = get_proposal_token_ids()
shape = proposal_token_ids.shape
batch_size = shape[0]
proposal_len = shape[1]
update_indices = []
for i in range(batch_size):
for j in range(proposal_len):
update_indices.append(i * (proposal_len + 1) + j + 1)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = proposal_token_ids.view(-1)
if spec_step == SpecStepKind.FIRST_PROPOSAL:
if last_step == SpecStepKind.PREFILL:# TODO: when last step is prefill, just update the input ids for last seqence_id onely.
pass
if last_step == SpecStepKind.SCORE_DECODE:# TODO: when last step is score decode, fix input ids、seq_lens、input_positions use accepte token ids
accept_token_ids, accept_seq_ids = get_accepted_token_ids()
chidren_req_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
grid = [1, 1, 1]
_update_input_tokens[grid](
accept_seq_ids, accept_seq_ids.shape[0],
accept_token_ids, accept_token_ids.shape[1],
chidren_req_ids, chidren_req_ids.shape[0],
model_input.input_tokens, model_input.input_tokens.shape[0],
model_input.input_positions,
model_input.seq_lens,
model_input.attn_metadata.seq_lens_tensor,
model_input.attn_metadata.seq_lens,
model_input.attn_metadata.slot_mapping,
model_input.attn_metadata.seq_start_loc,
model_input.attn_metadata.context_lens_tensor,
)
return model_input
This diff is collapsed.
from typing import Union
from vllm.sequence import Sequence
from typing import Sequence as GenericSequence
class ZeroOverheadSequence(Sequence):
def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
self.effective_output_len : int = 0
def fix_last_token_id(self, token_id: int) -> None:
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
if effect_offset < 0:
self.data._output_token_ids[effect_offset] = token_id
if len(self.data._new_appended_tokens) >= effect_offset * -1:
self.data._new_appended_tokens[effect_offset] = token_id
self.data._cached_all_token_ids[effect_offset] = token_id
self.effective_output_len += 1
def remove_last_place_holder(self, count):
self.data._output_token_ids = self.data._output_token_ids[:-1 * count]
self.data._new_appended_tokens = self.data._new_appended_tokens[:-1 * count]
self.data._cached_all_token_ids = self.data._cached_all_token_ids[:-1 * count]
self.data._num_computed_tokens -= count
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
return self.data.output_token_ids[:self.effective_output_len]
def zero_overhead_get_output_len(self) -> int:
return self.effective_output_len
def zero_overhead_get_last_token_id(self) -> int:
if self.effective_output_len == 0:
return self.data._prompt_token_ids[-1]
return self.data._output_token_ids[self.effective_output_len - 1]
def zero_overhead_get_len(self) -> int:
return self.effective_output_len + len(self.data._prompt_token_ids)
def get_output_token_ids_to_return(
self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.zero_overhead_get_output_token_ids()
output_len = self.zero_overhead_get_output_len()
# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
self._last_output_token_ids_offset = output_len
# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[self.effective_output_len - 1]
if num_new_tokens == 0:
return []
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
return self.data._cached_all_token_ids[-num_new_tokens : effect_offset]
from array import array
import numpy as np
from itertools import chain, count
from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import get_proposal_lens_list, record_proposal_token_ids
SeqId = int
TargetSeqId = int
TokenId = int
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
def score_proposals(
self,
execute_model_req: ExecuteModelRequest,
proposals: SpeculativeProposals,
) -> SpeculativeScores:
"""Score the proposed tokens via the scorer model.
This converts each input sequence to a set of k+1 target sequences. The
target sequences have the unique continuations to be scored and a
unique sequence ID that is different from all input sequence ids.
If a speculative sequence length would exceed the max model length, then
no speculation is produced for that sequence.
Args:
execute_model_req: The execution request.
proposals: The speculative proposals to score.
Returns:
SpeculativeScores: The scores of each speculative token, along with
which sequences were ignored during scoring.
"""
proposal_lens_list = get_proposal_lens_list()
record_proposal_token_ids(proposals.proposal_token_ids)
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist() # place holder tokens
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
proposals for proposals in proposal_token_ids_list
if VLLM_INVALID_TOKEN_ID not in proposals
]
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
return self._contract_batch(
execute_model_req.seq_group_metadata_list,
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
def _contract_non_speculative(
self, scores: SpeculativeScores,
seq_group_metadata_list: List[SequenceGroupMetadata],
non_spec_indices: List[int], non_spec_outputs: SpeculativeScores,
has_prompt_log: bool) -> SpeculativeScores:
"""
Augment input `scores` with non-speculative requests outputs.
This includes decode requests with speculation turned off, as well
as prefill requests when `enable_chunked_prefill` is set.
For the latter, prefills are further separated into terminal and
non-terminal chunks (from which no token is sampled).
"""
if not non_spec_indices:
return scores
if has_prompt_log:
# When prompt_logprobs is enabled, prefills yield output token
# (and respective prob) in the last entry (prompt|out):
# [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..].
# With chunked prefill, non-terminal chunks have -1 on each
# position: they're still picked, but they're discarded later.
seq_meta = seq_group_metadata_list
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1
for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(
range(len(non_spec_outputs.token_ids)))
nospec_sampled_token_idxs = async_tensor_h2d(nospec_sampled_token_idxs, torch.int32,
self._device,
True)
non_spec_indices = async_tensor_h2d(non_spec_indices, torch.int32,
self._device,
True)
scores.token_ids[non_spec_indices, :1] = \
non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1)
scores.probs[non_spec_indices, :1, :] = \
non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1)
scores.logprobs[non_spec_indices, :1, :] = \
non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1)
if scores.hidden_states is not None:
assert non_spec_outputs.hidden_states is not None
scores.hidden_states[non_spec_indices, :1, :] = \
non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1)
return scores
\ No newline at end of file
import copy
import weakref
from typing import Dict, List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker_base import DelegateWorkerBase
class ZeroOverheadMultiStepWorker(MultiStepWorker):
def init_device(self) -> None:
self.worker.init_device()
self._proposer = ZeroOverheadTop1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
set_spec_step(SpecStepKind.FIRST_PROPOSAL)
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
set_spec_step(SpecStepKind.OTHER_PROPOSAL)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
set_spec_step(SpecStepKind.SCORE_DECODE)
filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
def _filter_model_output_zero_overhead(self,
expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
"""
Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the
model to retain the outputs of only those sequences indicated by the
provided indices.
Args:
expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model.
output_indices_to_retain (torch.Tensor): Indices of the model
outputs to retain.
Returns:
List[SamplerOutput]: A list containing the filtered model
outputs for the specified indices.
"""
indices_of_seq_with_bonus_tokens = async_tensor_h2d(output_indices_to_retain, torch.int32,
self.device,
True)
return [
SamplerOutput(
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_probs is not None
else None),
logprobs=(
expanded_batch_output.logprobs[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.logprobs is not None else None),
sampled_token_ids=(expanded_batch_output.
sampled_token_ids[indices_of_seq_with_bonus_tokens]
if expanded_batch_output.sampled_token_ids
is not None else None))
for expanded_batch_output in expanded_batch_outputs
]
\ No newline at end of file
This diff is collapsed.
import os
from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list
class ZeroOverheadTop1Proposer(Top1Proposer):
def _merge_outputs(
self,
batch_size: int,
proposal_len: int,
maybe_sampler_output: Optional[List[SamplerOutput]],
proposal_lens: List[int],
nonzero_proposal_len_indices: List[int],
sampler_transposed: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens = torch.tensor(-1,
dtype=torch.long,
device=self._device).expand(
batch_size, proposal_len)
proposal_probs = torch.tensor(0,
dtype=torch.float32,
device=self._device).expand(
batch_size, proposal_len,
self._vocab_size)
proposal_lens_tensor = torch.tensor(0,
dtype=torch.long,
device=self._device).expand(
len(proposal_lens))
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)
proposal_lens_list = [0 for i in range(batch_size)]
for indices in nonzero_proposal_len_indices:
proposal_lens_list[indices] = proposal_len
record_proposal_lens_list(proposal_lens_list)
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
self._device,
True)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens = proposal_tokens.new_full(
size=(batch_size, *proposal_tokens.shape[1:]),
fill_value=-1,
)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
entire_proposal_probs = proposal_probs.new_zeros(
batch_size,
*proposal_probs.shape[1:],
)
entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
proposal_tokens, proposal_probs = (
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
self._device,
True)
return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
from typing import Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceStatus
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
def __init__(self, max_model_len, get_tokenizer_for_seq):
super().__init__(max_model_len, get_tokenizer_for_seq)
def maybe_stop_sequence(
self,
seq: ZeroOverheadSequence,
new_char_count: int,
sampling_params: SamplingParams,
lora_req: Optional[LoRARequest] = None,
) -> None:
"""Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.zero_overhead_get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
\ No newline at end of file
from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
from vllm.zero_overhead.sequence import ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
def __init__(self, tokenizer_group):
super().__init__(tokenizer_group)
def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
Returns:
The number of characters added to the output text.
"""
eff_length = seq.get_prompt_len() + seq.effective_output_len
all_input_ids = seq.get_token_ids()[ : eff_length]
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# computation for each logprob.
if seq.tokens is None:
(seq.tokens, seq.prefix_offset,
seq.read_offset) = convert_prompt_ids_to_tokens(
tokenizer=tokenizer,
prompt_ids=all_input_ids[:-1],
skip_special_tokens=prms.skip_special_tokens,
)
(new_tokens, new_decoded_token_text, prefix_offset,
read_offset) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.spaces_between_special_tokens,
)
# Decode logprobs
logprobs = seq.output_logprobs[-1]
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
# If the token was generated this iteration,
# use the provided text.
if token_id == token_id_generated_this_iteration:
sample_logprob.decoded_token = new_decoded_token_text
continue
if (sample_logprob.decoded_token is None
and token_id != VLLM_INVALID_TOKEN_ID):
all_input_ids_with_logprob = previous_tokens + [token_id]
(_, new_text, _, _) = detokenize_incrementally(
tokenizer=tokenizer,
all_input_ids=all_input_ids_with_logprob,
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=prms.skip_special_tokens,
spaces_between_special_tokens=prms.
spaces_between_special_tokens,
)
sample_logprob.decoded_token = new_text
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
return len(new_decoded_token_text)
\ No newline at end of file
from enum import Enum
import os
import torch
import vllm.envs as envs
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
def is_zero_no_thread():
return zero_no_thread and envs.VLLM_ZERO_OVERHEAD
class SpecStepKind(Enum):
KIND_DEFAULT = 0
PREFILL = 1
FIRST_PROPOSAL = 2
OTHER_PROPOSAL = 3
SCORE_DECODE = 4
class ZeroOverheadSpecContext():
def __init__(self):
self.step_kind = SpecStepKind.KIND_DEFAULT
self.last_step = SpecStepKind.KIND_DEFAULT
self.proposal_lens_list = None
self.proposal_token_ids = None
self.accepted_token_ids = None
self.accepted_seq_ids = None
spec_context = ZeroOverheadSpecContext()
def set_spec_step(_step):
global spec_context
spec_context.last_step = spec_context.step_kind
spec_context.step_kind = _step
def get_spec_step():
return spec_context.step_kind
def get_spec_last_step():
return spec_context.last_step
def record_proposal_lens_list(list):
global spec_context
spec_context.proposal_lens_list = list
def get_proposal_lens_list():
return spec_context.proposal_lens_list
def record_proposal_token_ids(tensor):
global spec_context
spec_context.proposal_token_ids = tensor
def get_proposal_token_ids():
return spec_context.proposal_token_ids
def record_accepted_token_ids(tensor, seq_ids):
global spec_context
spec_context.accepted_token_ids = tensor
spec_context.accepted_seq_ids = seq_ids
def get_accepted_token_ids():
return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
# 零消耗调度不在默认流上推理,用以规避runtime引入的内存申请流同步问题。
alloc_stream = {}
def zero_overhead_stream(target_device):
"""Asynchronously create a tensor and copy it from host to device."""
if target_device not in alloc_stream.keys():
alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
return alloc_stream[target_device]
import torch
from collections import defaultdict
from typing import Optional
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.zero_overhead.v1.outputs import ZeroV1ModelRunnerOutput
requsets_valid_token_len = {}
def check_stop(request: Request,
max_model_len: int,
pooler_output: Optional[torch.Tensor] = None,
use_valid_token_len:bool = False) -> bool:
if use_valid_token_len:
if request.request_id not in requsets_valid_token_len:
requsets_valid_token_len[request.request_id] = 0
return False
valid_output_len = requsets_valid_token_len[request.request_id]
else:
valid_output_len = request.num_output_tokens
valid_num_tokens = request.num_prompt_tokens + valid_output_len
if (valid_num_tokens >= max_model_len
or valid_output_len >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
if request.pooling_params:
if pooler_output is not None:
request.status = RequestStatus.FINISHED_STOPPED
return True
return False
sampling_params = request.sampling_params
assert sampling_params is not None
last_token_id = request.output_token_ids[valid_output_len - 1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False
def zero_overhead_update_from_output(scheduler:Scheduler,
scheduler_output: SchedulerOutput,
model_runner_output: ZeroV1ModelRunnerOutput):
global requsets_valid_token_len
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
# fix last model out in zero overhead
if model_runner_output.fix_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_req_ids):
if req_id not in scheduler.requests:
continue
request = scheduler.requests[req_id]
generated_token_ids = model_runner_output.fix_sampled_token_ids[req_idx]
if req_id not in requsets_valid_token_len:
requsets_valid_token_len[req_id] = 0
valid_output_len = requsets_valid_token_len[req_id]
fix_offset = valid_output_len - request.num_output_tokens
if isinstance(generated_token_ids, int):
request._output_token_ids[fix_offset] = generated_token_ids
request._all_token_ids[fix_offset] = generated_token_ids
requsets_valid_token_len[req_id] += 1
generated_token_ids = [generated_token_ids]
else:
valid_output_end = valid_output_len + len(generated_token_ids) - request.num_output_tokens
if valid_output_end == 0:
request._output_token_ids[fix_offset : ] = generated_token_ids
request._all_token_ids[fix_offset : ] = generated_token_ids
else:
request._output_token_ids[fix_offset : valid_output_end] = generated_token_ids
request._all_token_ids[fix_offset : valid_output_end] = generated_token_ids
requsets_valid_token_len[req_id] += len(generated_token_ids)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
for num_new, output_token_id in enumerate(new_token_ids, 1):
stopped = check_stop(request, scheduler.max_model_len, True)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_idx]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output, True)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_idx, req_idx + 1)
if new_token_ids and scheduler.structured_output_manager.should_advance(
request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)
# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
assert not prompt_logprobs_tensors
# fix last model out in zero overhead
if model_runner_output.fix_draft_req_ids is not None:
for req_idx, req_id in enumerate(model_runner_output.fix_draft_req_ids):
if req_id not in scheduler.requests:
continue
request = scheduler.requests[req_id]
# Add newly generated spec token ids to the request.
if model_runner_output.fix_draft_tokens_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
model_runner_output.fix_draft_tokens_ids[req_idx])
else:
request.spec_token_ids = model_runner_output.fix_draft_tokens_ids[req_idx]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for request in scheduler.running:
req_id = request.request_id
if request.is_finished():
if req_id in requsets_valid_token_len:
requsets_valid_token_len.pop(req_id)
continue
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
# The request was not scheduled in this step.
new_running.append(request)
continue
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[
req_index] if sampled_token_ids else []
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
spec_decoding_stats = scheduler.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
if request.has_encoder_inputs:
scheduler._free_encoder_inputs(request)
stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for num_new, output_token_id in enumerate(new_token_ids, 1):
request.append_output_token_ids(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
if model_runner_output.is_output_valid:
stopped = check_stop(request, scheduler.max_model_len,
False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
pooler_output = None
if pooler_outputs:
if model_runner_output.is_output_valid:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, scheduler.max_model_len,
pooler_output,
False)
if stopped:
kv_transfer_params = scheduler._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None \
and request.sampling_params.logprobs is not None and logprobs:
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and scheduler.structured_output_manager.should_advance(
request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)
# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if scheduler.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
if model_runner_output.is_output_valid:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:
# Add EngineCoreOutput for this Request.
outputs[request.client_index].append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids,
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
if stopped:
if req_id in requsets_valid_token_len:
requsets_valid_token_len.pop(req_id)
else:
new_running.append(request)
scheduler.running = new_running
# KV Connector: update state for finished KV Transfers.
scheduler._update_from_kv_xfer_finished(model_runner_output)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
client_index: EngineCoreOutputs(outputs=outs)
for client_index, outs in outputs.items()
}
finished_req_ids = scheduler.finished_req_ids_dict
if finished_req_ids:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(
finished_requests=finished_set)
finished_req_ids.clear()
if engine_core_outputs:
# Return stats to only one of the front-ends.
next(iter(engine_core_outputs.values())).scheduler_stats = (
scheduler.make_stats(spec_decoding_stats))
return engine_core_outputs
def engine_core_step(core) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not core.scheduler.has_requests():
return {}, False
scheduler_output = core.scheduler.schedule()
model_output = core.execute_model(scheduler_output)
if isinstance(model_output, ZeroV1ModelRunnerOutput):
engine_core_outputs = zero_overhead_update_from_output(core.scheduler,
scheduler_output, model_output) # type: ignore
else:
engine_core_outputs = core.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
\ No newline at end of file
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, EagleProposer
class V1ZeroEagleProposer(EagleProposer):
def __init__(self, vllm_config, device, runner=None):
super().__init__(vllm_config, device, runner)
self.spec_scheduler_max_num_tokens = 0
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
# [batch_size]
sampling_metadata: SamplingMetadata,
decoding: bool = False,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
max_query_len = self.spec_scheduler_max_num_tokens
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens,
seq_lens=seq_lens,
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
slot_mapping=target_slot_mapping,
spec_layer_decoding=decoding
)
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builders[0].build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata
)
else:
raise ValueError(f"Unsupported method: {self.method}")
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if (decoding and self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
if attn_metadata.decode is not None:
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
skip_cuda_graphs=not decoding):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
if self.method == "deepseek_mtp":
hidden_states = last_hidden_states[last_token_indices]
else:
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.num_decodes = batch_size
attn_metadata.num_decode_tokens = batch_size
attn_metadata.num_prefills = 0
block_table = self.runner.attn_metadata_builders[0].block_table.get_device_tensor()[:batch_size, ...]
attn_metadata.decode = self.runner.attn_metadata_builders[0]._build_decode(
block_table_tensor=block_table,
seq_lens=seq_lens,
)
for i in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
if isinstance(attn_metadata, MLACommonMetadata):
attn_metadata.decode.seq_lens += 1
else:
attn_metadata.seq_lens += 1
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = block_table.gather(dim=1,
index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if (self.use_full_cuda_graph
and batch_size <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
if self.method in ["eagle", "eagle3"]:
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:batch_size] = (
attn_metadata.slot_mapping)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size +
1] = (
attn_metadata
.
query_start_loc
)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
elif self.method == "deepseek_mtp":
self.attn_metadata_cudagraph.num_actual_tokens = (
attn_metadata.num_actual_tokens)
self.attn_metadata_cudagraph.slot_mapping[:attn_metadata.num_decode_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.num_decodes = (
attn_metadata.num_decodes)
self.attn_metadata_cudagraph.num_decode_tokens = (
attn_metadata.num_decode_tokens)
self.attn_metadata_cudagraph.num_prefills = (
attn_metadata.num_prefills)
self.attn_metadata_cudagraph.decode.seq_lens[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.seq_lens)
if i == 0:
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.decode.block_table[:attn_metadata.num_decode_tokens] = (
attn_metadata.decode.block_table)
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
ret_hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states[:batch_size]
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment