Commit 0be169ad authored by lizhigong
Browse files

Debug and fix some errors in output handling

parent 18b9f67c
......@@ -1244,255 +1244,8 @@ class LLMEngine:
return None
# Post-process one queued batch of model outputs and turn it into request
# outputs.
#
# NOTE(review): this listing comes from a diff view that dropped the original
# indentation, so the nesting of the statements below is not visible here.
#
# Args:
#   ctx_output_queue: queue of pending output tuples. The head entry is only
#       peeked when `request_id` is given; otherwise it is removed with
#       popleft(). An empty queue makes the method return immediately.
#   ctx_request_outputs: list that collects the objects returned by
#       RequestOutputFactory.create(); it is cleared each time it has been
#       passed to `self.process_request_outputs_callback`.
#   ctx_multi_step_stream_outputs: truthy when outputs should also be created
#       for steps that are not the last step.
#   request_id: when set, only the sequence group whose metadata carries this
#       request id is processed.
#
# Returns:
#   None on every path.
def fix_process_model_output(self,
ctx_output_queue,
ctx_request_outputs,
ctx_multi_step_stream_outputs,
request_id: Optional[str] = None) -> None:
# Single timestamp reused for all first-token / last-token time updates below.
now = time.time()
if len(ctx_output_queue) == 0:
return None
# Get pending async postprocessor
if request_id:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output, skip) = ctx_output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output,
skip) = ctx_output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
# More than one entry in `outputs` is only accepted for multi-step scheduling
# or speculative decoding (asserted below).
has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if self.scheduler_config.is_multi_step:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, len(seq_group_metadata_list))
elif self.speculative_config:
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match decodes structure (costly with prompt_logprobs).
num_prefills = sum(sg.is_prompt
for sg in seq_group_metadata_list)
prefills, decodes = outputs[:num_prefills], outputs[
num_prefills:]
outputs_by_sequence_group = create_output_by_sequence_group(
decodes,
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
outputs_by_sequence_group = [p.outputs for p in prefills
] + outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
elif len(outputs) == 1:
outputs_by_sequence_group = outputs
else:
# `outputs` is empty: there is nothing to process for this queue entry.
return None
# Determine the requests we need to operate on
if request_id:
indices = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
if seq_group_meta.request_id == request_id:
assert i not in skip # Cannot be called twice
indices.append(i)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if not indices:
return
else:
indices = range(len(seq_group_metadata_list)) # type: ignore
# Index bookkeeping for the loop below:
#   finished_before   - groups that were already finished when visited
#   finished_now      - groups that became finished during this call
#   empty_seq_indices - groups whose samples were all VLLM_INVALID_TOKEN_ID
finished_before: List[int] = []
finished_now: List[int] = []
empty_seq_indices: List[int] = []
for i in indices:
if i in skip:
continue
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group: SequenceGroup = scheduled_seq_group.seq_group
if seq_group.is_finished():
finished_before.append(i)
continue
output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if not is_async:
#print("hello")
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size or 0)
# Accumulate forward/execute timings from every SamplerOutput into the
# group's metrics (initialising them when they are still None).
if outputs:
for o in outputs:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.runner_type == "pooling":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
if seq_group.is_finished():
finished_now.append(i)
# Generate outputs for the requests that finished this iteration
for i in finished_now:
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx_request_outputs.append(request_output)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if request_id:
assert len(indices) == 1
skip.append(indices[0])
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# For multi-step without streaming, don't create outputs each iteration
if not is_last_step and not ctx_multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
# Create the outputs
for i in indices:
if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx_request_outputs.append(request_output)
# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx_multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
# Ignored sequence groups also get a request output, except unfinished groups
# whose sampling params request DELTA output.
for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs,
)
if request_output:
ctx_request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given)
if (ctx_request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before,
skip)
# Tracing
self.do_tracing(scheduler_outputs, finished_before)
return None
def _fix_last_step(
self, ctx_output_queue,ctx_request_outputs,
ctx_multi_step_stream_outputs,output: List[SamplerOutput],
self, output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
......@@ -1514,9 +1267,7 @@ class LLMEngine:
for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id:
sample.output_token = token_id[0]
seq.data._effective_length+=1
seq.fix_last_token_id(sample.output_token)
self.fix_process_model_output(ctx_output_queue,ctx_request_outputs,ctx_multi_step_stream_outputs)
break
def _advance_to_next_step(
......@@ -1612,9 +1363,9 @@ class LLMEngine:
last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
# if allow_async_output_proc:
# execute_model_req.async_callback = self.async_callbacks[
# virtual_engine]
#profile.ProfRangeAutoPush('model_executor')
outputs = self.model_executor.execute_model(
......@@ -1637,44 +1388,32 @@ class LLMEngine:
ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize()
self._fix_last_step(
ctx.output_queue,
ctx.request_outputs,
ctx.multi_step_stream_outputs,
outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
allow_async_output_proc = True
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Add results to the output_queue
ctx.append_output(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, outputs)
# Log stats.
self.do_log_stats(scheduler_outputs, outputs)
# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs
# Tracing
self.do_tracing(scheduler_outputs)
#profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests():
......@@ -1820,8 +1559,7 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
if not self.zero_overhead:
execute_model_req.async_callback = self.async_callbacks[
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
#profile.ProfRangeAutoPush('model_executor')
......
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, List, Optional, Tuple
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer
import os
class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop
......@@ -21,7 +22,6 @@ class StopChecker:
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
......@@ -44,104 +44,53 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if self.zero_overhead:
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
return
# NOTE: the new_char_count logic has been left unchanged for now.
if seq.get_output_len(self.zero_overhead) < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.zero_overhead_get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() >= sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else:
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id(self.zero_overhead) == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id(self.zero_overhead)
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len(self.zero_overhead) > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len(self.zero_overhead) == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod
def check_stop_strings(
......
......@@ -7,6 +7,7 @@ from array import array
from collections import defaultdict
from dataclasses import dataclass, field
from functools import reduce
import os
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
......@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True
_effective_length:int =0
_effective_length: int = 0
@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
......@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct,
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob
# Overwrite the token stored at the current "effective" position with
# `token_id` and advance `_effective_length`.
#
# `effect_offset` is `_effective_length` minus the number of output tokens;
# when it is negative it is used as a negative index into the output, newly
# appended and cached-all token lists.
# NOTE(review): indentation was lost in this listing, so it is not visible
# whether the final increment belongs to the `if` body - confirm against the
# repository.
def fix_effective_token_id(self, token_id: int,):
effect_offset = self._effective_length - len(self.output_token_ids)
if effect_offset < 0:
self._output_token_ids[effect_offset] = token_id
self._new_appended_tokens[effect_offset] = token_id
self._cached_all_token_ids[effect_offset] = token_id
self._effective_length += 1
# Return the total token count: output tokens plus prompt tokens.
def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
# Same as get_len(), but uses `_effective_length` in place of the number of
# stored output tokens.
def zero_overhead_get_len(self) -> int:
return self._effective_length + len(self._prompt_token_ids)
# Return the number of prompt tokens.
def get_prompt_len(self) -> int:
return len(self._prompt_token_ids)
# Return the number of stored output tokens.
def get_output_len(self) -> int:
return len(self._output_token_ids)
def zero_overhead_get_output_len(self) -> Tuple[int, ...]:
return self._effective_length
def zero_overhead_get_output_len(self) -> int:
return self._effective_length
# Return the cached list of all token ids (the list object itself, not a copy).
def get_token_ids(self) -> List[int]:
return self._cached_all_token_ids
......@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
return self.get_len() - self.get_num_computed_tokens()
# Return the most recent token id: the last output token, or the last prompt
# token when no output token has been stored yet.
def get_last_token_id(self) -> int:
if not self._output_token_ids:
return self._prompt_token_ids[-1]
return self._output_token_ids[-1]
def zero_overhead_get_last_token_id(self) -> int:
if self._effective_length==0:
if self._effective_length == 0:
return self._prompt_token_ids[-1]
return self._output_token_ids[self._effective_length-1]
return self._output_token_ids[self._effective_length - 1]
# Return the prompt token ids via the `prompt_token_ids` attribute.
def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids
# Return only the first `_effective_length` entries of `output_token_ids`.
def zero_overhead_get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids[:self._effective_length]
# Return all output token ids via the `output_token_ids` attribute.
def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids
......@@ -469,6 +485,7 @@ class Sequence:
self.read_offset = 0
# Input + output tokens
self.tokens: Optional[List[str]] = None
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
@property
def n_blocks(self) -> int:
......@@ -535,9 +552,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.get_output_token_ids()
return self.get_output_token_ids(self.zero_overhead)
output_len = self.get_output_len()
output_len = self.get_output_len(self.zero_overhead)
# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
......@@ -547,11 +564,16 @@ class Sequence:
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[-1]
if self.zero_overhead:
return self.data._cached_all_token_ids[self.data._effective_length - 1]
else:
return self.data._cached_all_token_ids[-1]
if num_new_tokens == 0:
return []
if self.zero_overhead:
return self.data._cached_all_token_ids[-num_new_tokens : self.data._effective_length]
return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int:
......@@ -591,34 +613,35 @@ class Sequence:
self.data.append_token_id(token_id, logprobs[token_id].logprob)
def fix_last_token_id(self, token_id: int) -> None:
self.data._output_token_ids[-1] = token_id
self.data._new_appended_tokens[-1] = token_id
self.data._cached_all_token_ids[-1] = token_id
self.data.fix_effective_token_id(token_id)
def get_len(self) -> int:
def get_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_len()
return self.data.get_len()
def zero_overhead_get_len(self) -> int:
return self.data.zero_overhead_get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int:
def get_output_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_output_len()
return self.data.get_output_len()
def zero_overhead_get_output_len(self) -> int:
return self.data.zero_overhead_get_output_len()
def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
def get_last_token_id(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_last_token_id()
return self.data.get_last_token_id()
def zero_overhead_get_last_token_id(self) -> int:
return self.data.zero_overhead_get_last_token_id()
def get_output_token_ids(self) -> Tuple[int, ...]:
def get_output_token_ids(self, zero_overhead = False) -> Tuple[int, ...]:
if zero_overhead:
return self.data.zero_overhead_get_output_token_ids()
return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float:
......
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Dict, List, Optional
import os
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup)
......@@ -108,11 +109,11 @@ class Detokenizer:
Returns:
The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
all_input_ids = seq.get_token_ids()
if self.zero_overhead:
eff_length=seq.get_prompt_len()+seq.data._effective_length
all_input_ids = seq.get_token_ids()[:eff_length]
print(f'{all_input_ids=}')
eff_length = seq.get_prompt_len() + seq.data._effective_length
all_input_ids = seq.get_token_ids()[ : eff_length]
token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment