Commit 0ecda6d1 authored by lizhigong
Browse files

debug spec decode zero overhead

parent 01c30741
...@@ -699,7 +699,7 @@ def _sample_with_torch( ...@@ -699,7 +699,7 @@ def _sample_with_torch(
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
sampled_token_ids_ = greedy_samples.unsqueeze(-1)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
...@@ -736,7 +736,8 @@ def _sample_with_torch( ...@@ -736,7 +736,8 @@ def _sample_with_torch(
probs[long_sample_indices], probs[long_sample_indices],
max_n_in_batch, max_n_in_batch,
seq_groups=seq_groups_arg) seq_groups=seq_groups_arg)
sampled_token_ids_ = \
multinomial_samples[sampling_type].to(torch.long)
if sampled_token_ids_tensor is not None: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[long_sample_indices] = \ sampled_token_ids_tensor[long_sample_indices] = \
...@@ -745,6 +746,7 @@ def _sample_with_torch( ...@@ -745,6 +746,7 @@ def _sample_with_torch(
else: else:
raise ValueError(f"Unsupported sampling type: {sampling_type}") raise ValueError(f"Unsupported sampling type: {sampling_type}")
print('###sampled_token_ids', sampled_token_ids_)
# Encapsulate arguments for computing Pythonized sampler # Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise. # results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType( maybe_deferred_args = SampleResultArgsType(
......
...@@ -910,6 +910,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -910,6 +910,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
accepted_token_ids, target_logprobs, select_indices_list, accept_lengths = self._verify_tokens( accepted_token_ids, target_logprobs, select_indices_list, accept_lengths = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores, execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots) proposals, execute_model_req.num_lookahead_slots)
print('###accepted_token_ids', accepted_token_ids)
# move kv_caches of selected tokens to right positions # move kv_caches of selected tokens to right positions
if self.tree_decoding: if self.tree_decoding:
...@@ -1340,6 +1341,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -1340,6 +1341,7 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
self._maybe_log_stage_times(*stage_times) self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when # First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format. # chunked prefill is enabled, the rest is decodes in multi-step format.
print('###sampler_output_list', sampler_output_list)
return sampler_output_list return sampler_output_list
def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
......
...@@ -11,6 +11,7 @@ from vllm.platforms import current_platform ...@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SequenceGroupMetadata, PromptLogprobs, SequenceGroupMetadata,
SequenceOutput) SequenceOutput)
from vllm.zero_overhead.utils import is_zero_overhead
SeqId = int SeqId = int
...@@ -139,7 +140,6 @@ def split_batch_by_proposal_len( ...@@ -139,7 +140,6 @@ def split_batch_by_proposal_len(
zero or not. We should remove this once vLLM supports per-sequence proposal zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch. lens in a batch.
""" """
nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], [])
for i, (seq_group, proposal_len) in enumerate( for i, (seq_group, proposal_len) in enumerate(
......
...@@ -902,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -902,6 +902,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions. # Tokens and positions.
if cuda_graph_pad_size: if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
print('###input_tokens', input_tokens)
assert self.runner.device is not None assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
...@@ -916,12 +917,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -916,12 +917,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for idx in range(3): for idx in range(3):
mrope_input_positions[idx].extend( mrope_input_positions[idx].extend(
itertools.repeat(0, cuda_graph_pad_size)) itertools.repeat(0, cuda_graph_pad_size))
print('###mrope_input_positions', mrope_input_positions)
input_positions_tensor = async_tensor_h2d(mrope_input_positions, input_positions_tensor = async_tensor_h2d(mrope_input_positions,
torch.long, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
else: else:
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
print('###input_positions', input_positions)
input_positions_tensor = async_tensor_h2d(input_positions, input_positions_tensor = async_tensor_h2d(input_positions,
torch.long, torch.long,
self.runner.device, self.runner.device,
...@@ -929,6 +932,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -929,6 +932,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Sequence and query lengths. # Sequence and query lengths.
if cuda_graph_pad_size: if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
print('###seq_lens', seq_lens)
# Attention metadata. # Attention metadata.
attn_metadata = self.attn_metadata_builder.build( attn_metadata = self.attn_metadata_builder.build(
...@@ -987,7 +991,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -987,7 +991,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
] ]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
return self.model_input_cls( ret = self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor, token_types=token_types_tensor,
...@@ -1001,6 +1005,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -1001,6 +1005,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
finished_requests_ids=self.finished_requests_ids, finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping, prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests) prompt_adapter_requests=prompt_adapter_requests)
print('###model_input', ret)
return ret
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
......
...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer ...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import is_zero_no_thread from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -301,7 +301,10 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -301,7 +301,10 @@ class ZeroOverheadEngine(LLMEngine):
) = self.scheduler[virtual_engine].schedule() ) = self.scheduler[virtual_engine].schedule()
if self.last_record is not None: if self.last_record is not None:
last_sampler = self.last_record[1] last_sampler = self.last_record[1]
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True) if spec_step == SpecStepKind.KIND_DEFAULT:
self.async_d2h = last_sampler.sampled_token_ids_tensor.to('cpu', non_blocking=True)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_d2h = last_sampler.to('cpu', non_blocking=True)
self.async_event.record() self.async_event.record()
self.q_recorder.put(self.last_record) self.q_recorder.put(self.last_record)
else: else:
...@@ -332,13 +335,18 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -332,13 +335,18 @@ class ZeroOverheadEngine(LLMEngine):
outputs = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
if len(outputs) == 1: for output in outputs:
self._advance_to_next_step( self._advance_to_next_step(
outputs[0], seq_group_metadata_list, output, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy scheduler_outputs.scheduled_seq_groups = [item for item in scheduler_outputs.scheduled_seq_groups] #deep copy
last_sampler = get_last_sampler() last_sampler = None
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs] spec_step = get_spec_step()
if spec_step == SpecStepKind.KIND_DEFAULT:
last_sampler = get_last_sampler()
elif spec_step == SpecStepKind.SCORE_DECODE:
last_sampler, _ = get_accepted_token_ids()
self.last_record = [outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step]
except Exception as e: except Exception as e:
print(f"thread_zero_overhead error : {e}") print(f"thread_zero_overhead error : {e}")
...@@ -357,13 +365,19 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -357,13 +365,19 @@ class ZeroOverheadEngine(LLMEngine):
virtual_engine = 0 virtual_engine = 0
ctx = self.scheduler_contexts[virtual_engine] ctx = self.scheduler_contexts[virtual_engine]
ctx.request_outputs.clear() ctx.request_outputs.clear()
outputs, last_sampler, seq_group_metadata_list, scheduler_outputs = recode_output outputs, last_sampler, seq_group_metadata_list, scheduler_outputs, spec_step = recode_output
ctx.seq_group_metadata_list = seq_group_metadata_list ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize() if spec_step == SpecStepKind.KIND_DEFAULT:
self._fix_last_step( self.async_event.synchronize()
outputs, last_sampler, seq_group_metadata_list, self._fix_last_step(
scheduler_outputs.scheduled_seq_groups) outputs, last_sampler, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
elif spec_step == SpecStepKind.SCORE_DECODE:
self.async_event.synchronize()
self._fix_spec_decode_steps(
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# is_first_step_output is True only when the num_steps of all # is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1, # the sequences are 1. When the num_steps > 1,
...@@ -430,6 +444,33 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -430,6 +444,33 @@ class ZeroOverheadEngine(LLMEngine):
sample.output_token = token_id sample.output_token = token_id
seq.fix_last_token_id(sample.output_token) seq.fix_last_token_id(sample.output_token)
break break
def _fix_spec_decode_steps(
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]):
sample_out_list = self.async_d2h.tolist()
group_idx = 0
for seq_group_metadata, accept_token_ids, scheduled_seq_group in \
zip(seq_group_metadata_list, sample_out_list, scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
if seq_group.is_finished():
group_idx += 1
continue
if seq_group_metadata.do_sample:
assert len(seq_group.seqs) == 1
seq : ZeroOverheadSequence = seq_group.seqs[0]
remove_count = 0
for token_id in accept_token_ids:
if token_id == -1:
remove_count += 1
else:
seq.fix_last_token_id(token_id)
seq.remove_last_place_holder(remove_count)
group_idx += 1
def no_thread_step(self): def no_thread_step(self):
virtual_engine = 0 virtual_engine = 0
......
...@@ -11,7 +11,65 @@ from vllm.sequence import SequenceGroupMetadata ...@@ -11,7 +11,65 @@ from vllm.sequence import SequenceGroupMetadata
from vllm.utils import async_tensor_h2d, flatten_2d_lists from vllm.utils import async_tensor_h2d, flatten_2d_lists
from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder from vllm.worker.model_runner import ModelInputForGPU, ModelInputForGPUBuilder
from vllm.zero_overhead.sampler import get_last_sampler from vllm.zero_overhead.sampler import get_last_sampler
from vllm.zero_overhead.update_input import UpdateInputTokens from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_proposal_token_ids, get_spec_last_step, get_spec_step
import triton
import triton.language as tl
# Triton kernel that rewrites the model-input tensors of a proposal step from
# the token ids accepted in the preceding scoring step.
#
# As written, for every even index 2*k of `chidren_req_ids` it searches
# `accepted_req_ids` for the same id, counts the leading entries of that
# request's accepted-token row that are not -1, and then overwrites the two
# slots [2*k, 2*k + 1] of `input_tokens`, `input_positions`,
# `context_lens_tensor`, `slot_mapping`, `seq_lens`, `seq_lens_meta` and
# `seq_lens_tensor`. Finally it rebuilds `seq_start_loc` as a running sum of
# `seq_lens`.
#
# NOTE(review): the indentation of this body is not visible in this view, so
# the exact nesting of the statements below must be confirmed against the file.
@triton.jit
def _update_input_tokens(
accepted_req_ids,
accepted_req_ids_len,
accepted_token_ids,
accepted_token_len,
chidren_req_ids,
chidren_req_ids_len,
input_tokens,
input_tokens_len,
input_positions,
seq_lens,
seq_lens_meta,
seq_lens_tensor,
slot_mapping,
seq_start_loc,
context_lens_tensor,
):
chidren_req_ids_ = tl.load(chidren_req_ids + tl.arange(0, chidren_req_ids_len))
# NOTE(review): this load is sized by `chidren_req_ids_len`, not
# `accepted_req_ids_len` -- looks like a copy/paste slip; confirm.
accepted_req_ids_ = tl.load(accepted_req_ids + tl.arange(0, chidren_req_ids_len))
# NOTE(review): `/` is true division; an integer trip count (`//`) is
# presumably intended here -- confirm.
for seq_id_idx in range(chidren_req_ids_len / 2):
# Only every second entry of the request-id list is used as the lookup key.
seq_id = chidren_req_ids_[2 * seq_id_idx]
for i in range(accepted_req_ids_len):
if seq_id == accepted_req_ids_[i]:
# NOTE(review): the second argument of the outer `tl.arange` is itself a
# `tl.arange(...)`; presumably `i * accepted_token_len + tl.arange(0,
# accepted_token_len)` was meant -- confirm.
accepted_token_ids_ = tl.load(accepted_token_ids + tl.arange(i * accepted_token_len, tl.arange(0, accepted_token_len)))
# Count accepted tokens up to (not including) the first -1 entry.
accepted_token_counter = 0
for j in range(accepted_token_len):
if accepted_token_ids_[j] == -1:
break
accepted_token_counter += 1
# Every slot accepted: feed the last two accepted tokens as inputs.
if accepted_token_counter == accepted_token_len:
tl.store(input_tokens + seq_id_idx * 2 + tl.arange(0, 2), accepted_token_ids_[-2:])
else:
# Otherwise the first slot is zeroed and the second gets the last accepted token.
tl.store(input_tokens + seq_id_idx * 2, 0)
tl.store(input_tokens + seq_id_idx * 2 + 1, accepted_token_ids_[accepted_token_counter - 1])
# NOTE(review): the lines below assign into elements of a loaded block
# (`input_pos[0] = ...`); verify that Triton supports this form.
input_pos = tl.load(input_positions + seq_id_idx * 2 + tl.arange(0, 2))
input_pos[0] = 0
# NOTE(review): subtracts `accepted_req_ids_len - accepted_token_counter`;
# `accepted_token_len` may have been intended -- confirm.
input_pos[1] = input_pos[1] - (accepted_req_ids_len - accepted_token_counter)
# The same pair is written to both the positions and the context lengths.
tl.store(input_positions + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(context_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# Slot mapping reuses the pair with the first element replaced by -1.
input_pos[0] = -1
tl.store(slot_mapping + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# Sequence lengths: 1 for the first slot, adjusted position + 1 for the second.
input_pos[0] = 1
input_pos[1] = input_pos[1] + 1
tl.store(seq_lens + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_meta + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
tl.store(seq_lens_tensor + seq_id_idx * 2 + tl.arange(0, 2), input_pos)
# Rebuild the start offsets as a running sum over the sequence lengths.
seq_lens_ = tl.load(seq_lens + tl.arange(0, input_tokens_len))
# NOTE(review): Triton's documented helper is `tl.zeros_like`; `tl.zero_like`
# is presumably a typo -- confirm.
seq_start_loc_ = tl.zero_like(seq_start_loc)
for i in range(input_tokens_len):
seq_start_loc_[i + 1] = seq_start_loc_[i] + seq_lens_[i]
tl.store(seq_start_loc + tl.arange(0, input_tokens_len + 1), seq_start_loc_)
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder): class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
...@@ -34,22 +92,80 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder): ...@@ -34,22 +92,80 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
# Build the base model input via the parent builder, then patch
# `model_input.input_tokens` on the device according to the current
# speculative-decoding step kind (`get_spec_step()`) and the previous step
# kind (`get_spec_last_step()`). Nothing is patched when `get_last_sampler()`
# returns None. Returns the (possibly patched) model input.
def build(self) -> ModelInputForGPU: def build(self) -> ModelInputForGPU:
model_input = super().build() model_input = super().build()
# Debug trace of the unpatched input.
print('###model_input', model_input)
last_sampler = get_last_sampler() last_sampler = get_last_sampler()
spec_step = get_spec_step()
last_step = get_spec_last_step()
if last_sampler is not None: if last_sampler is not None:
# Default step: for each request id that also appears in
# `last_sampler.seq_ids`, copy column 0 of the last sampled-token tensor into
# that request's input-token slot.
if spec_step == SpecStepKind.KIND_DEFAULT:
update_indices = []
select_indices = []
for i, seq_id in enumerate(self.req_ids):
for j, seq_id_ in enumerate(last_sampler.seq_ids):
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
if len(select_indices) > 0:
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0]
# Later proposal steps: both sub-branches below currently run the same
# statements (copy column 0 of the last sampled tokens for every request).
if spec_step == SpecStepKind.OTHER_PROPOSAL:
if last_step == SpecStepKind.OTHER_PROPOSAL: # copy last sampled token ids to input tokens directly.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
if last_step == SpecStepKind.FIRST_PROPOSAL: # TODO: adjust the number of input tokens to 1 per request.
update_indices = [i for i in range(len(self.req_ids))]
update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[update_indices, 0]
# Scoring step: each request owns `proposal_len + 1` consecutive input slots;
# slot 0 of each group is left untouched and slots 1..proposal_len receive the
# recorded proposal token ids (flattened row by row).
if spec_step == SpecStepKind.SCORE_DECODE:
proposal_token_ids = get_proposal_token_ids()
shape = proposal_token_ids.shape
batch_size = shape[0]
proposal_len = shape[1]
update_indices = [] update_indices = []
select_indices = [] for i in range(batch_size):
for i, seq_id in enumerate(self.req_ids): for j in range(proposal_len):
for j, seq_id_ in enumerate(last_sampler.seq_ids): update_indices.append(i * (proposal_len + 1) + j + 1)
if seq_id == seq_id_:
select_indices.append(j)
update_indices.append(i)
break
select_indices = async_tensor_h2d(select_indices, torch.long,
self.runner.device,
self.runner.pin_memory)
update_indices = async_tensor_h2d(update_indices, torch.long, update_indices = async_tensor_h2d(update_indices, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
if len(select_indices) > 0: model_input.input_tokens[update_indices] = proposal_token_ids.view(-1)
model_input.input_tokens[update_indices] = last_sampler.sampled_token_ids_tensor[select_indices, 0] if spec_step == SpecStepKind.FIRST_PROPOSAL:
if last_step == SpecStepKind.PREFILL:# TODO: when the last step is prefill, update the input ids for the last sequence id only.
pass
if last_step == SpecStepKind.SCORE_DECODE:# TODO: when the last step is score decode, fix input ids, seq_lens and input_positions using the accepted token ids.
accept_token_ids, accept_seq_ids = get_accepted_token_ids()
chidren_req_ids = async_tensor_h2d(self.req_ids, torch.long,
self.runner.device,
self.runner.pin_memory)
# Single-program launch of the `_update_input_tokens` kernel.
grid = [1, 1, 1]
# NOTE(review): the kernel's parameters are ordered `seq_lens, seq_lens_meta,
# seq_lens_tensor`, while this call passes `model_input.seq_lens`,
# `attn_metadata.seq_lens_tensor`, `attn_metadata.seq_lens` -- the last two
# look swapped relative to the parameter names; confirm.
_update_input_tokens[grid](
accept_seq_ids, accept_seq_ids.shape[0],
accept_token_ids, accept_token_ids.shape[1],
chidren_req_ids, chidren_req_ids.shape[0],
model_input.input_tokens, model_input.input_tokens.shape[0],
model_input.input_positions,
model_input.seq_lens,
model_input.attn_metadata.seq_lens_tensor,
model_input.attn_metadata.seq_lens,
model_input.attn_metadata.slot_mapping,
model_input.attn_metadata.seq_start_loc,
model_input.attn_metadata.context_lens_tensor,
)
# Debug trace of the patched input.
print('###zero_model_input', model_input)
return model_input return model_input
...@@ -359,6 +359,7 @@ def _sample_with_torch( ...@@ -359,6 +359,7 @@ def _sample_with_torch(
sampled_token_ids_tensor[long_sample_indices] = \ sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long) multinomial_samples[sampling_type].to(torch.long)
print('###sampled_token_ids', last_sampler.sampled_token_ids_tensor)
# Encapsulate arguments for computing Pythonized sampler # Encapsulate arguments for computing Pythonized sampler
# results, whether deferred or otherwise. # results, whether deferred or otherwise.
maybe_deferred_args = SampleResultArgsType( maybe_deferred_args = SampleResultArgsType(
......
...@@ -19,6 +19,11 @@ class ZeroOverheadSequence(Sequence): ...@@ -19,6 +19,11 @@ class ZeroOverheadSequence(Sequence):
self.data._cached_all_token_ids[effect_offset] = token_id self.data._cached_all_token_ids[effect_offset] = token_id
self.effective_output_len += 1 self.effective_output_len += 1
def remove_last_place_holder(self, count):
    """Drop the trailing ``count`` placeholder entries from this sequence.

    Truncates ``_output_token_ids``, ``_new_appended_tokens`` and
    ``_cached_all_token_ids`` on ``self.data`` by ``count`` entries and
    lowers ``_num_computed_tokens`` by the same amount.

    Args:
        count: Number of trailing entries to discard. A value of zero or
            less leaves the sequence untouched.
    """
    # ``seq[:-0]`` is ``seq[:0]`` (an empty slice), so a zero count must be
    # handled before slicing or every token would be discarded.
    if count <= 0:
        return
    data = self.data
    data._output_token_ids = data._output_token_ids[:-count]
    data._new_appended_tokens = data._new_appended_tokens[:-count]
    data._cached_all_token_ids = data._cached_all_token_ids[:-count]
    data._num_computed_tokens -= count
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]: def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
return self.data.output_token_ids[:self.effective_output_len] return self.data.output_token_ids[:self.effective_output_len]
......
...@@ -15,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, ...@@ -15,6 +15,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import get_proposal_lens_list, record_proposal_token_ids
SeqId = int SeqId = int
TargetSeqId = int TargetSeqId = int
...@@ -48,8 +49,9 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer): ...@@ -48,8 +49,9 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
which sequences were ignored during scoring. which sequences were ignored during scoring.
""" """
proposal_lens_list = np.zeros(proposals.proposal_lens.shape, dtype=int).tolist() #zero_overhead todo fix proposal_lens_list = get_proposal_lens_list()
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist() record_proposal_token_ids(proposals.proposal_token_ids)
proposal_token_ids_list = np.zeros(proposals.proposal_token_ids.shape, dtype=int).tolist() # place holder tokens
# Filter the list to ignore invalid proposals. # Filter the list to ignore invalid proposals.
proposal_token_ids_list_without_skips = [ proposal_token_ids_list_without_skips = [
...@@ -64,14 +66,11 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer): ...@@ -64,14 +66,11 @@ class ZeroOverheadBatchExpansionTop1Scorer(BatchExpansionTop1Scorer):
proposal_lens_list=proposal_lens_list, proposal_lens_list=proposal_lens_list,
) )
#print('###execute_model_req', execute_model_req)
#print('###target_seq_group_metadata_list', target_seq_group_metadata_list)
target_sampler_output = self._scorer_worker.execute_model( target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone( execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list)) seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output" assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
#print('###target_sampler_output', target_sampler_output)
if not non_spec_indices: if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled # All sequence groups in batch have spec decoding enabled
return self._contract_batch_all_spec( return self._contract_batch_all_spec(
......
...@@ -11,6 +11,7 @@ from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, ...@@ -11,6 +11,7 @@ from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer from vllm.zero_overhead.spec_decode.top1_proproser import ZeroOverheadTop1Proposer
from vllm.zero_overhead.utils import SpecStepKind, set_spec_step
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
...@@ -45,7 +46,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker): ...@@ -45,7 +46,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
For multi step worker, this indicator shall be True. For multi step worker, this indicator shall be True.
""" """
print('###execute_model_req', execute_model_req)
self._raise_if_unsupported(execute_model_req) self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token. # Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the # Perform a forward pass on the expanded batch and filter the
...@@ -53,7 +53,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker): ...@@ -53,7 +53,6 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
expanded_request, indices_of_seq_with_bonus_tokens =\ expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request( self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step) execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times. # Run model sample_len times.
model_outputs: List[SamplerOutput] = [] model_outputs: List[SamplerOutput] = []
if current_platform.is_cuda_alike() and isinstance( if current_platform.is_cuda_alike() and isinstance(
...@@ -72,20 +71,20 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker): ...@@ -72,20 +71,20 @@ class ZeroOverheadMultiStepWorker(MultiStepWorker):
# TODO: Remove this branch once DraftModelRunner supports TP>1 # TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's # and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..) # supports_gpu_multi_step(..)
set_spec_step(SpecStepKind.FIRST_PROPOSAL)
for _ in range(sample_len): for _ in range(sample_len):
print('###self.worker.execute_model', sample_len)
print('###expanded_request', expanded_request)
model_output: List[SamplerOutput] = self.worker.execute_model( model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request) execute_model_req=expanded_request)
assert (len(model_output) == 1 assert (len(model_output) == 1
), "composing multistep workers not supported" ), "composing multistep workers not supported"
model_output = model_output[0] model_output = model_output[0]
print('###model_output', model_output) set_spec_step(SpecStepKind.OTHER_PROPOSAL)
self._append_new_tokens( self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list, model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens) indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output) model_outputs.append(model_output)
set_spec_step(SpecStepKind.SCORE_DECODE)
filtered_model_outputs = self._filter_model_output_zero_overhead( filtered_model_outputs = self._filter_model_output_zero_overhead(
model_outputs, indices_of_seq_with_bonus_tokens) model_outputs, indices_of_seq_with_bonus_tokens)
......
...@@ -27,8 +27,9 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, ...@@ -27,8 +27,9 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
get_all_seq_ids_and_request_ids, Logits) get_all_seq_ids_and_request_ids, Logits)
from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer from vllm.spec_decode.batch_expansion import BatchExpansionTreeStyleScorer
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker, prepare_prefill_hidden_states
from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer from vllm.zero_overhead.spec_decode.batch_expansion import ZeroOverheadBatchExpansionTop1Scorer
from vllm.zero_overhead.utils import SpecStepKind, record_accepted_token_ids, set_spec_step
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
...@@ -49,7 +50,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output, ...@@ -49,7 +50,7 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs, get_all_num_logprobs,
get_sampled_token_logprobs, nvtx_range, get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len) split_batch_by_proposal_len)
from vllm.utils import resolve_obj_by_qualname from vllm.utils import async_tensor_h2d, resolve_obj_by_qualname
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
...@@ -113,6 +114,90 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker): ...@@ -113,6 +114,90 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self._configure_model_sampler_for_spec_decode() self._configure_model_sampler_for_spec_decode()
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
                 skip_proposer: bool) -> List[SamplerOutput]:
    """Run a single generation step without any speculation. The input is
    sent to the proposer and scorer model so that the KV cache is consistent
    between the two. When skip_proposer is True, the proposer model is
    not called, meaning that the kv-cache in proposer for requests is not
    updated, so they cannot enable spec decode in the rest decoding.

    Args:
        execute_model_req: The request to execute. It is mutated in place:
            `kvcache_slot_to_be_moved` may be set, and when the proposer
            runs, `previous_hidden_states` and `spec_step_idx` are set.
        skip_proposer: When True, only the scorer worker is executed.

    Returns:
        `[sampler_output]` from the scorer worker, or the result of
        `_serialize_sampler_output_no_logprobs` when `self._disable_logprobs`
        is set.
    """
    # Hand any pending KV-cache slot moves to this request and clear the
    # pending value so it is attached to one request only.
    if self.tree_decoding and self.kvcache_slot_to_be_moved is not None:
        execute_model_req.kvcache_slot_to_be_moved = self.kvcache_slot_to_be_moved
        self.kvcache_slot_to_be_moved = None

    # Record the step kind in the zero-overhead spec context before the
    # scorer worker is executed.
    set_spec_step(SpecStepKind.PREFILL)
    sampler_output = self.scorer_worker.execute_model(execute_model_req)
    assert len(sampler_output) == 1
    sampler_output = sampler_output[0]

    # Store hidden states from target model execution, BxD.
    hidden_states = sampler_output.hidden_states
    if hidden_states is not None:
        # Only decodes and prefill terminal chunks need a hidden state.
        seq_group_meta_with_hidden = [
            sg for sg in execute_model_req.seq_group_metadata_list
            if sg.do_sample
        ]
        if any(seq.is_prompt for seq in seq_group_meta_with_hidden):
            # Drop hidden_states with no prediction (eg non-terminal chunks).
            # The subtraction is non-zero exactly where the sampled id
            # differs from VLLM_INVALID_TOKEN_ID, so `torch.where(...)[0]`
            # yields the first-dimension indices of those entries.
            hidden_states = hidden_states[
                torch.where(sampler_output.sampled_token_ids -
                            VLLM_INVALID_TOKEN_ID)[0]]
        # NOTE: the variant below, which gated the hidden-state bookkeeping
        # on `not skip_proposer`, is disabled; the active code that follows
        # it runs regardless of `skip_proposer`.
        # if not skip_proposer:
        #     if self.previous_hidden_states is None and len(
        #             seq_group_meta_with_hidden):
        #         self.previous_hidden_states = HiddenStates(
        #             hidden_states, seq_group_meta_with_hidden)
        #     elif self.previous_hidden_states and len(
        #             seq_group_meta_with_hidden):
        #         self.previous_hidden_states.update(hidden_states,
        #                                            seq_group_meta_with_hidden)
        # Create the hidden-state record on first use, otherwise update it.
        # Nothing is stored when no sequence group samples in this step.
        if self.previous_hidden_states is None and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states = HiddenStates(
                hidden_states, seq_group_meta_with_hidden)
        elif self.previous_hidden_states and len(
                seq_group_meta_with_hidden):
            self.previous_hidden_states.update(hidden_states,
                                               seq_group_meta_with_hidden)

    # Store logits from target model execution.
    if self.tree_decoding:
        logits = sampler_output.logits
        if logits is not None:
            # Same create-or-update pattern as for the hidden states, but
            # keyed on the full sequence group metadata list.
            if self.previous_logits is None:
                self.previous_logits = Logits(
                    logits, execute_model_req.seq_group_metadata_list)
            else:
                self.previous_logits.update(
                    logits, execute_model_req.seq_group_metadata_list)

    if not skip_proposer:
        # We prepare the prefill hidden states here so that there is no
        # additional complexity in worker for spec_decode vs non_spec_decode
        # flow and execute_model doesn't need additional modifications.
        execute_model_req.previous_hidden_states = \
            prepare_prefill_hidden_states(
                sampler_output.prefill_hidden_states)
        # Run the proposer once per configured prefill step, tagging the
        # request with the index of the current step.
        for i in range(self._num_spec_prefill_steps):
            execute_model_req.spec_step_idx = i
            self.proposer_worker.execute_model(execute_model_req)

    # Build the return value before the device tensors are cleared below.
    sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
        execute_model_req=execute_model_req, sampler_output=sampler_output)
                                if self._disable_logprobs else
                                [sampler_output])

    # Clear device tensors from sampler output. This reduces communication
    # overhead when the engine runs in a different process than the workers.
    # When logprobs are enabled the returned list holds this same object, so
    # the returned output is cleared as well.
    sampler_output.sampled_token_probs = None
    sampler_output.sampled_token_ids = None
    sampler_output.logprobs = None
    return sampler_output_to_return
@nvtx_range("spec_decode_worker._verify_tokens") @nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens( def _verify_tokens(
self, self,
...@@ -338,7 +423,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker): ...@@ -338,7 +423,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list)
# Serialize tensor to CPU Python list. # Serialize tensor to CPU Python list.
#print('###accepted_token_ids_by_step', accepted_token_ids_by_step) #accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
record_accepted_token_ids(accepted_token_ids, seq_ids)
# Construct the output on a per-step, per-sequence basis. # Construct the output on a per-step, per-sequence basis.
# Non-terminal prefill chunks will end up here as rows with just -1s # Non-terminal prefill chunks will end up here as rows with just -1s
...@@ -428,8 +514,7 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker): ...@@ -428,8 +514,7 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
num_logprobs = num_logprobs_per_seq[sequence_index] num_logprobs = num_logprobs_per_seq[sequence_index]
step_output_token_ids.append( step_output_token_ids.append(
create_sequence_group_output( create_sequence_group_output(
token_id=accepted_token_ids_by_step[step_index] token_id = 0,
[sequence_index],
token_id_logprob_rank=accepted_token_id_ranks_by_step[ token_id_logprob_rank=accepted_token_id_ranks_by_step[
step_index][sequence_index], step_index][sequence_index],
token_id_logprob=accepted_token_id_logprobs_by_step[ token_id_logprob=accepted_token_id_logprobs_by_step[
...@@ -460,9 +545,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker): ...@@ -460,9 +545,8 @@ class ZeroOverheadSpecDecodeWorker(SpecDecodeWorker):
self._maybe_log_stage_times(*stage_times) self._maybe_log_stage_times(*stage_times)
# First `n_prefills` entries will contain prefills SamplerOutput when # First `n_prefills` entries will contain prefills SamplerOutput when
# chunked prefill is enabled, the rest is decodes in multi-step format. # chunked prefill is enabled, the rest is decodes in multi-step format.
print('###sampler_output_list', sampler_output_list)
return sampler_output_list return sampler_output_list
def _track_sequences_with_bonus_tokens( def _track_sequences_with_bonus_tokens(
self, seq_ids: List[int], self, seq_ids: List[int],
......
...@@ -11,6 +11,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase ...@@ -11,6 +11,7 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.util import sampler_output_to_torch
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.zero_overhead.utils import record_proposal_lens_list
class ZeroOverheadTop1Proposer(Top1Proposer): class ZeroOverheadTop1Proposer(Top1Proposer):
...@@ -48,13 +49,14 @@ class ZeroOverheadTop1Proposer(Top1Proposer): ...@@ -48,13 +49,14 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( proposal_tokens, proposal_probs, *_ = sampler_output_to_torch(
sampler_output, sampler_transposed) sampler_output, sampler_transposed)
proposal_lens_list = [0 for i in range(batch_size)]
for indices in nonzero_proposal_len_indices:
proposal_lens_list[indices] = proposal_len
record_proposal_lens_list(proposal_lens_list)
nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32, nonzero_proposal_len_indices = async_tensor_h2d(nonzero_proposal_len_indices, torch.int32,
self._device, self._device,
True) True)
proposal_len = [proposal_len for i in range(batch_size)]
proposal_len = async_tensor_h2d(proposal_len, torch.long,
self._device,
True)
# Now, reformat the output GPU tensors such that each sequence has # Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1] # a proposal. the proposal can be empty, e.g. [-1, -1, -1]
...@@ -74,10 +76,9 @@ class ZeroOverheadTop1Proposer(Top1Proposer): ...@@ -74,10 +76,9 @@ class ZeroOverheadTop1Proposer(Top1Proposer):
entire_proposal_tokens, entire_proposal_tokens,
entire_proposal_probs, entire_proposal_probs,
) )
proposal_lens_tensor = torch.zeros(batch_size, proposal_lens_tensor = async_tensor_h2d(proposal_lens_list, torch.long,
dtype=torch.long, self._device,
device=self._device) True)
proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len
return proposal_tokens, proposal_probs, proposal_lens_tensor return proposal_tokens, proposal_probs, proposal_lens_tensor
\ No newline at end of file
from enum import Enum
import os import os
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
...@@ -9,4 +10,55 @@ def is_zero_overhead(): ...@@ -9,4 +10,55 @@ def is_zero_overhead():
return zero_overhead return zero_overhead
def is_zero_no_thread():
    """Return the conjunction of the `zero_no_thread` and `zero_overhead` flags.

    The result is truthy only when both module-level flags are truthy.
    NOTE(review): `zero_no_thread` is defined outside the visible part of
    this file - confirm how it is initialised.
    """
    return zero_no_thread and zero_overhead
\ No newline at end of file
class SpecStepKind(Enum):
    """Kind of execution step recorded in the zero-overhead spec context.

    NOTE(review): apart from ``KIND_DEFAULT`` being the initial value of the
    context, the meaning of each member is inferred from its name - confirm
    against the call sites of ``set_spec_step``.
    """

    KIND_DEFAULT = 0
    PREFILL = 1
    FIRST_PROPOSAL = 2
    OTHER_PROPOSAL = 3
    SCORE_DECODE = 4


class ZeroOverheadSpecContext:
    """Mutable holder for state shared between speculative-decoding steps."""

    def __init__(self):
        # Kind of the step most recently passed to ``set_spec_step`` and the
        # kind that was current immediately before it.
        self.step_kind = SpecStepKind.KIND_DEFAULT
        self.last_step = SpecStepKind.KIND_DEFAULT
        # Values stored by the ``record_*`` helpers; None until recorded.
        self.proposal_lens_list = None
        self.proposal_token_ids = None
        self.accepted_token_ids = None
        self.accepted_seq_ids = None


# Module-level singleton. The helpers below only mutate its attributes and
# never rebind the name, so none of them needs a ``global`` declaration.
spec_context = ZeroOverheadSpecContext()


def set_spec_step(_step: SpecStepKind) -> None:
    """Make ``_step`` the current step kind, remembering the previous one."""
    spec_context.last_step = spec_context.step_kind
    spec_context.step_kind = _step


def get_spec_step() -> SpecStepKind:
    """Return the step kind most recently set via ``set_spec_step``."""
    return spec_context.step_kind


def get_spec_last_step() -> SpecStepKind:
    """Return the step kind that was current before the latest change."""
    return spec_context.last_step


def record_proposal_lens_list(list) -> None:  # noqa: A002
    """Store the proposal lengths.

    The parameter shadows the ``list`` builtin; the name is kept so that
    existing keyword callers keep working. The object is stored by
    reference, not copied.
    """
    spec_context.proposal_lens_list = list


def get_proposal_lens_list():
    """Return the recorded proposal lengths, or None if none were recorded."""
    return spec_context.proposal_lens_list


def record_proposal_token_ids(tensor) -> None:
    """Store the proposal token ids (kept by reference, not copied)."""
    spec_context.proposal_token_ids = tensor


def get_proposal_token_ids():
    """Return the recorded proposal token ids, or None if none were recorded."""
    return spec_context.proposal_token_ids


def record_accepted_token_ids(tensor, seq_ids) -> None:
    """Store the accepted token ids together with their sequence ids."""
    spec_context.accepted_token_ids = tensor
    spec_context.accepted_seq_ids = seq_ids


def get_accepted_token_ids() -> tuple:
    """Return ``(accepted_token_ids, accepted_seq_ids)`` as last recorded.

    Both elements are None until ``record_accepted_token_ids`` is called.
    """
    return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment