Commit 0ecda6d1 authored by lizhigong's avatar lizhigong
Browse files

debug spec decode zero overhead

parent 01c30741
......@@ -699,7 +699,7 @@ def _sample_with_torch(
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
sampled_token_ids_ = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
......@@ -736,7 +736,8 @@ def _sample_with_torch(
probs[long_sample_indices],
max_n_in_batch,
seq_groups=seq_groups_arg)
sampled_token_ids_ = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[long_sample_indices] = \
......@@ -745,6 +746,7 @@ def _sample_with_torch(
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
print('###sampled_token_ids', sampled_token_ids_)
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
......
......@@ -910,6 +910,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
accepted_token_ids, target_logprobs, select_indices_list, accept_lengths = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
print('###accepted_token_ids', accepted_token_ids)
# move kv_caches of selected tokens to right positions
if self.tree_decoding:
......@@ -1340,6 +1341,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
print('###sampler_output_list', sampler_output_list)
return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
......
......@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SequenceGroupMetadata,
SequenceOutput)
from vllm.zero_overhead.utils import is_zero_overhead
SeqId = int
......@@ -139,7 +140,6 @@ def split_batch_by_proposal_len(
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
for i, (seq_group, proposal_len) in enumerate(
......
......@@ -902,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions.
if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
print('###input_tokens', input_tokens)
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
......@@ -916,12 +917,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for idx in range(3):
mrope_input_positions[idx].extend(
itertools.repeat(0, cuda_graph_pad_size))
print('###mrope_input_positions', mrope_input_positions)
input_positions_tensor = async_tensor_h2d(mrope_input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
else:
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
print('###input_positions', input_positions)
input_positions_tensor = async_tensor_h2d(input_positions,
torch.long,
self.runner.device,
......@@ -929,6 +932,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths.
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
print('###seq_lens', seq_lens)
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
......@@ -987,7 +991,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls(
ret = self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
token_types=token_types_tensor,
......@@ -1001,6 +1005,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests)
print('###model_input', ret)
return ret
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
......
......@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import is_zero_no_thread
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step
logger = init_logger(__name__)
......@@ -301,7 +301,10 @@ class ZeroOverheadEngine(LLMEngine):
) = self.scheduler[virtual_engine].schedule()
if self.last_record is not None:
last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record()
self.q_recorder.put(self.last_record)
else:
......@@ -332,13 +335,18 @@ class ZeroOverheadEngine(LLMEngine):
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
if len(outputs) == 1:
for output in outputs:
self._advance_to_next_step(
outputs[0], seq_group_metadata_list,
output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
last_sampler = get_last_sampler()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs]
last_sampler = None
spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE:
last_sampler, _ = get_accepted_token_ids()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
except Exception as e:
print(f"thread_zero_overhead error : {e}")
......@@ -357,13 +365,19 @@ class ZeroOverheadEngine(LLMEngine):
virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear()
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_event.synchronize()
self._fix_last_step(
outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_event.synchronize()
self._fix_spec_decode_steps(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
......@@ -430,6 +444,33 @@ class ZeroOverheadEngine(LLMEngine):
sample.output_token = token_id
seq.fix_last_token_id(sample.output_token)
break
def _fix_spec_decode_steps(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]):
sample_out_list = self.async_d2h.tolist()
group_idx = 0
for seq_group_metadata, accept_token_ids, scheduled_seq_group in \
zip(seq_group_metadata_list, sample_out_list, scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
group_idx += 1
continue
if seq_group_metadata.do_sample:
assert len(seq_group.seqs) == 1
seq : ZeroOverheadSequence = seq_group.seqs[0]
remove_count = 0
for token_id in accept_token_ids:
if token_id == -1:
remove_count += 1
else:
seq.fix_last_token_id(token_id)
seq.remove_last_place_holder(remove_count)
group_idx += 1
def no_thread_step(self):
virtual_engine = 0
......
......@@ -11,7 +11,65 @@ from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.update_input import UpdateInputTokens
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
import triton
import triton.language as tl
# Triton kernel that appears intended to rewrite the next draft step's inputs
# in place from the accepted token ids of the previous scoring step.
# NOTE(review): indentation is not preserved in this view, so the nesting of
# the statements below (in particular which lines belong to the `else:`
# branch) cannot be confirmed from here.
# NOTE(review): every length argument is used as a `tl.arange` bound; Triton
# documents those bounds as compile-time constants, while these are plain
# kernel arguments - confirm this compiles.
@triton.jit
def _update_input_tokens(
# ids of the requests that have accepted tokens, and how many there are
accepted_req_ids,
accepted_req_ids_len,
# accepted token ids, addressed below as row `i` of width `accepted_token_len`
accepted_token_ids,
accepted_token_len,
# request ids of the current batch; only every second entry is read below
chidren_req_ids,
chidren_req_ids_len,
# buffers that are overwritten in place
input_tokens,
input_tokens_len,
input_positions,
seq_lens,
# NOTE(review): the visible launch site passes `attn_metadata.seq_lens_tensor`
# for `seq_lens_meta` and `attn_metadata.seq_lens` for `seq_lens_tensor` -
# confirm the argument order is intended.
seq_lens_meta,
seq_lens_tensor,
slot_mapping,
seq_start_loc,
context_lens_tensor,
):
chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
# NOTE(review): the offsets here are built from `chidren_req_ids_len`, not
# `accepted_req_ids_len` - confirm.
accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
# Walks the batch in pairs of slots (index 2 * seq_id_idx and the one after).
# NOTE(review): `/` is true division; `//` may have been intended for a loop bound.
for seq_id_idx in range(chidren_req_ids_len / 2):
seq_id = chidren_req_ids_[2 * seq_id_idx]
# Linear search for the matching accepted request.
for i in range(accepted_req_ids_len):
if seq_id == accepted_req_ids_[i]:
# NOTE(review): the end argument of the outer `tl.arange` is itself a
# `tl.arange`; this looks like it was meant to be
# `i * accepted_token_len + tl.arange(0, accepted_token_len)` - confirm.
accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
# Count the entries that precede the first -1.
# NOTE(review): confirm `break` is supported inside a Triton loop.
accepted_token_counter = 0
for j in range(accepted_token_len):
if accepted_token_ids_[j] == -1:
break
accepted_token_counter += 1
# No -1 in the row: the last two accepted ids go into the slot pair.
if accepted_token_counter == accepted_token_len:
tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
else:
# Otherwise the first slot gets 0 and the second the last accepted id.
tl.store(input_tokens + seq_id_idx * 2, 0)
tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
# NOTE(review): element assignment on a loaded Triton value - confirm it is
# supported; the same pattern recurs below.
input_pos[0] = 0
# NOTE(review): subtracts `accepted_req_ids_len` (number of requests) minus
# the counter; `accepted_token_len` may have been intended - confirm.
input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
# The same pair is written to the positions and the context lengths.
tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# First slot of the pair is set to -1 before the pair is written to slot_mapping.
input_pos[0] = -1
tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# Pair becomes (1, position + 1) and is written to all three length buffers.
input_pos[0] = 1
input_pos[1] = input_pos[1] + 1
tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# Rebuild seq_start_loc as a running sum of seq_lens.
seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
# NOTE(review): Triton documents `tl.zeros_like`; confirm `tl.zero_like` exists.
seq_start_loc_ = tl.zero_like(seq_start_loc)
for i in range(input_tokens_len):
seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
......@@ -34,22 +92,80 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
def build(self) -> ModelInputForGPU:
model_input = super().build()
print('###model_input', model_input)
last_sampler = get_last_sampler()
spec_step = get_spec_step()
last_step = get_spec_last_step()
if last_sampler is not None:
if spec_step == SpecStepKind.KIND_DEFAULT:
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
if len(select_indices) > 0:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
if spec_step == SpecStepKind.OTHER_PROPOSAL:
if last_step == SpecStepKind.OTHER_PROPOSAL: # copy last sampled token ids to input tokens directly.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
if last_step == SpecStepKind.FIRST_PROPOSAL: # TODO: ajust input tokens number to 1 per request.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
if spec_step == SpecStepKind.SCORE_DECODE:
proposal_token_ids = get_proposal_token_ids()
shape = proposal_token_ids.shape
batch_size = shape[0]
proposal_len = shape[1]
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
for i in range(batch_size):
for j in range(proposal_len):
update_indices.append(i * (proposal_len + 1) + j + 1)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
if len(select_indices) > 0:
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = proposal_token_ids.view(-1)
if spec_step == SpecStepKind.FIRST_PROPOSAL:
if last_step == SpecStepKind.PREFILL:# TODO: when last step is prefill, just update the input ids for last seqence_id onely.
pass
if last_step == SpecStepKind.SCORE_DECODE:# TODO: when last step is score decode, fix input ids、seq_lens、input_positions use accepte token ids
accept_token_ids, accept_seq_ids = get_accepted_token_ids()
chidren_req_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
grid = [1, 1, 1]
_update_input_tokens[grid](
accept_seq_ids, accept_seq_ids.shape[0],
accept_token_ids, accept_token_ids.shape[1],
chidren_req_ids, chidren_req_ids.shape[0],
model_input.input_tokens, model_input.input_tokens.shape[0],
model_input.input_positions,
model_input.seq_lens,
model_input.attn_metadata.seq_lens_tensor,
model_input.attn_metadata.seq_lens,
model_input.attn_metadata.slot_mapping,
model_input.attn_metadata.seq_start_loc,
model_input.attn_metadata.context_lens_tensor,
)
print('###zero_model_input', model_input)
return model_input
......@@ -359,6 +359,7 @@ def _sample_with_torch(
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
print('###sampled_token_ids', last_sampler.sampled_token_ids_tensor)
# Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType(
......
......@@ -19,6 +19,11 @@ class ZeroOverheadSequence(Sequence):
self.data._cached_all_token_ids[effect_offset] = token_id
self.effective_output_len += 1
def remove_last_place_holder(self, count):
    """Drop the last ``count`` placeholder tokens from this sequence's data.

    Trims ``count`` trailing entries from ``_output_token_ids``,
    ``_new_appended_tokens`` and ``_cached_all_token_ids`` and lowers
    ``_num_computed_tokens`` by the same amount.

    Args:
        count: Number of trailing entries to remove. ``0`` is a no-op.

    Raises:
        ValueError: If ``count`` is negative.
    """
    if count < 0:
        raise ValueError(f"count must be non-negative, got {count}")
    if count == 0:
        # ``seq[:-0]`` is ``seq[:0]`` (empty), so a zero count must never
        # reach the slices below or every token would be discarded.
        return
    data = self.data
    data._output_token_ids = data._output_token_ids[:-count]
    data._new_appended_tokens = data._new_appended_tokens[:-count]
    data._cached_all_token_ids = data._cached_all_token_ids[:-count]
    data._num_computed_tokens -= count
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
    """Return the output token ids cut to ``effective_output_len`` entries."""
    all_output_ids = self.data.output_token_ids
    visible_len = self.effective_output_len
    return all_output_ids[:visible_len]
......
......@@ -15,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import get_proposal_lens_list, record_proposal_token_ids
SeqId = int
TargetSeqId = int
......@@ -48,8 +49,9 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
which sequences were ignored during scoring.
"""
proposal_lens_list = np.zeros(proposals.proposal_lens.shape, dtype=int).tolist() #zero_overhead todo fix
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist()
proposal_lens_list = get_proposal_lens_list()
record_proposal_token_ids(proposals.proposal_token_ids)
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist() # place holder tokens
# Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [
......@@ -64,14 +66,11 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
proposal_lens_list=proposal_lens_list,
)
#print('###execute_model_req', execute_model_req)
#print('###target_seq_group_metadata_list', target_seq_group_metadata_list)
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
#print('###target_sampler_output', target_sampler_output)
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec(
......
......@@ -11,6 +11,7 @@ from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
......@@ -45,7 +46,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
For multi step worker, this indicator shall be True.
"""
print('###execute_model_req', execute_model_req)
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
......@@ -53,7 +53,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance(
......@@ -72,20 +71,20 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
set_spec_step(SpecStepKind.FIRST_PROPOSAL)
for _ in range(sample_len):
print('###self.worker.execute_model', sample_len)
print('###expanded_request', expanded_request)
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
print('###model_output', model_output)
set_spec_step(SpecStepKind.OTHER_PROPOSAL)
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
set_spec_step(SpecStepKind.SCORE_DECODE)
filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens)
......
......@@ -27,8 +27,9 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, prepare_prefill_hidden_states
from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer
from vllm.zero_overhead.utils import SpecStepKind, record_accepted_token_ids, set_spec_step
if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
......@@ -49,7 +50,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname
from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine
......@@ -113,6 +114,90 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self._configure_model_sampler_for_spec_decode()
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
"""Run a single generation step without any speculation. The input is
sent to the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
# Forward a pending KV-cache slot move with this request, then clear it so
# it is attached only once.
if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
self.kvcache_slot_to_be_moved = None
# Mark the shared zero-overhead spec context as being in the PREFILL step
# before the scorer runs.
set_spec_step(SpecStepKind.PREFILL)
sampler_output = self.scorer_worker.execute_model(execute_model_req)
# Exactly one SamplerOutput is expected from the scorer for this step.
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
# Store hidden states from target model execution, BxD.
hidden_states = sampler_output.hidden_states
if hidden_states is not None:
# Only decodes and prefill terminal chunks need a hidden state.
seq_group_meta_with_hidden = [
sg for sg in execute_model_req.seq_group_metadata_list
if sg.do_sample
]
if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
# Drop hidden_states with no prediction (eg non-terminal chunks)
hidden_states = hidden_states[
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
# NOTE: a variant of the update below that was guarded by
# `if not skip_proposer:` has been disabled; the statements that follow
# contain no `skip_proposer` check of their own.
# NOTE(review): indentation is not preserved in this view - confirm these
# statements still sit inside the `hidden_states is not None` branch.
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
# Store logits from target model execution.
if self.tree_decoding:
logits = sampler_output.logits
if logits is not None:
# First call creates the Logits holder; later calls update it.
if self.previous_logits is None:
self.previous_logits = Logits(
logits, execute_model_req.seq_group_metadata_list)
else:
self.previous_logits.update(
logits, execute_model_req.seq_group_metadata_list)
if not skip_proposer:
# We prepare the prefill hidden states here so that there no
# additional complexity in worker for spec_decode vs non_spec_decode
# flow and execute_model doesn't need additional modifications.
execute_model_req.previous_hidden_states = \
prepare_prefill_hidden_states(
sampler_output.prefill_hidden_states)
# Run the proposer once per configured spec prefill step, tagging the
# request with the step index each time.
for i in range(self._num_spec_prefill_steps):
execute_model_req.spec_step_idx = i
self.proposer_worker.execute_model(execute_model_req)
# With logprobs disabled the output goes through
# _serialize_sampler_output_no_logprobs; otherwise it is returned as-is in
# a one-element list.
sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
execute_model_req=execute_model_req, sampler_output=sampler_output)
if self._disable_logprobs else
[sampler_output])
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
sampler_output.sampled_token_probs = None
sampler_output.sampled_token_ids = None
sampler_output.logprobs = None
return sampler_output_to_return
@nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens(
self,
......@@ -338,7 +423,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize tensor to CPU Python list.
#print('###accepted_token_ids_by_step', accepted_token_ids_by_step)
#accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
record_accepted_token_ids(accepted_token_ids, seq_ids)
# Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s
......@@ -428,8 +514,7 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append(
create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index]
[sequence_index],
token_id = 0,
token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[
......@@ -460,9 +545,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format.
print('###sampler_output_list', sampler_output_list)
return sampler_output_list
def _track_sequences_with_bonus_tokens(
self, seq_ids: List[int],
......
......@@ -11,6 +11,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list
class ZeroOverheadTop1Proposer(Top1Proposer):
......@@ -48,13 +49,14 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed)
proposal_lens_list = [0 for i in range(batch_size)]
for indices in nonzero_proposal_len_indices:
proposal_lens_list[indices] = proposal_len
record_proposal_lens_list(proposal_lens_list)
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
self._device,
True)
proposal_len = [proposal_len for i in range(batch_size)]
proposal_len = async_tensor_h2d(proposal_len, torch.long,
self._device,
True)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
......@@ -74,10 +76,9 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
entire_proposal_tokens,
entire_proposal_probs,
)
proposal_lens_tensor = torch.zeros(batch_size,
dtype=torch.long,
device=self._device)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
self._device,
True)
return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
from enum import Enum
import os
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
......@@ -9,4 +10,55 @@ def is_zero_overhead():
return zero_overhead
def is_zero_no_thread():
return zero_no_thread and zero_overhead
\ No newline at end of file
return zero_no_thread and zero_overhead
class SpecStepKind(Enum):
    """Tag for the step the zero-overhead spec-decode pipeline is in."""

    # Value the shared spec context starts with.
    KIND_DEFAULT = 0
    # Set before the scorer runs a step without speculation.
    PREFILL = 1
    # Set before the first draft-model forward pass of a proposal round.
    FIRST_PROPOSAL = 2
    # Set after a draft-model forward pass, i.e. for the remaining passes.
    OTHER_PROPOSAL = 3
    # Set once the draft loop is done, ahead of scoring.
    SCORE_DECODE = 4
class ZeroOverheadSpecContext:
    """Mutable holder for spec-decode state shared through this module.

    Tracks the current and previous ``SpecStepKind`` plus the values that
    the ``record_*`` helpers store and the ``get_*`` helpers read back.
    """

    def __init__(self):
        # Both step markers start at the default kind.
        self.step_kind = self.last_step = SpecStepKind.KIND_DEFAULT
        # Recorded values; all empty until a record_* helper fills them.
        self.proposal_lens_list = None
        self.proposal_token_ids = None
        self.accepted_token_ids = None
        self.accepted_seq_ids = None


# Single module-level instance used by every accessor in this module.
spec_context = ZeroOverheadSpecContext()
def set_spec_step(_step):
    """Make ``_step`` the current step of the shared spec context.

    The step that was current before the call is kept as ``last_step``.
    """
    ctx = spec_context
    # Right-hand side is evaluated first, so last_step receives the old kind.
    ctx.last_step, ctx.step_kind = ctx.step_kind, _step
def get_spec_step():
    """Return the step kind currently stored in the shared spec context."""
    current_step = spec_context.step_kind
    return current_step
def get_spec_last_step():
    """Return the step kind that was current before the latest step change."""
    previous_step = spec_context.last_step
    return previous_step
def record_proposal_lens_list(list):
    """Store the given proposal-length list in the shared spec context.

    The parameter name shadows the builtin ``list`` inside this function;
    it is left as-is because it is part of the call signature.
    """
    setattr(spec_context, "proposal_lens_list", list)
def get_proposal_lens_list():
    """Return the proposal-length list last stored in the spec context."""
    recorded_lens = spec_context.proposal_lens_list
    return recorded_lens
def record_proposal_token_ids(tensor):
    """Store the proposal token ids in the shared spec context."""
    setattr(spec_context, "proposal_token_ids", tensor)
def get_proposal_token_ids():
    """Return the proposal token ids last stored in the spec context."""
    recorded_ids = spec_context.proposal_token_ids
    return recorded_ids
def record_accepted_token_ids(tensor, seq_ids):
    """Store accepted token ids and their sequence ids in the spec context."""
    ctx = spec_context
    ctx.accepted_token_ids = tensor
    ctx.accepted_seq_ids = seq_ids
def get_accepted_token_ids():
    """Return ``(accepted_token_ids, accepted_seq_ids)`` from the spec context."""
    ctx = spec_context
    token_ids = ctx.accepted_token_ids
    seq_ids = ctx.accepted_seq_ids
    return token_ids, seq_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment