Commit 0be169ad authored by lizhigong's avatar lizhigong
Browse files

debug and fix some error about outputs

parent 18b9f67c
...@@ -1244,255 +1244,8 @@ class LLMEngine: ...@@ -1244,255 +1244,8 @@ class LLMEngine:
return None return None
def fix_process_model_output(self,
ctx_output_queue,
ctx_request_outputs,
ctx_multi_step_stream_outputs,
request_id: Optional[str] = None) -> None:
now = time.time()
if len(ctx_output_queue) == 0:
return None
# Get pending async postprocessor
if request_id:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output, skip) = ctx_output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output,
skip) = ctx_output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if self.scheduler_config.is_multi_step:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, len(seq_group_metadata_list))
elif self.speculative_config:
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match decodes structure (costly with prompt_logprobs).
num_prefills = sum(sg.is_prompt
for sg in seq_group_metadata_list)
prefills, decodes = outputs[:num_prefills], outputs[
num_prefills:]
outputs_by_sequence_group = create_output_by_sequence_group(
decodes,
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
outputs_by_sequence_group = [p.outputs for p in prefills
] + outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
elif len(outputs) == 1:
outputs_by_sequence_group = outputs
else:
return None
# Determine the requests we need to operate on
if request_id:
indices = []
for i, seq_group_meta in enumerate(seq_group_metadata_list):
if seq_group_meta.request_id == request_id:
assert i not in skip # Cannot be called twice
indices.append(i)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if not indices:
return
else:
indices = range(len(seq_group_metadata_list)) # type: ignore
finished_before: List[int] = []
finished_now: List[int] = []
empty_seq_indices: List[int] = []
for i in indices:
if i in skip:
continue
seq_group_meta = seq_group_metadata_list[i]
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group: SequenceGroup = scheduled_seq_group.seq_group
if seq_group.is_finished():
finished_before.append(i)
continue
output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
output = [outputs_by_sequence_group[0][i]]
# tree style speculative decoding may generate empty output in first step
if self.tree_decoding and outputs and isinstance(output[0], CompletionSequenceGroupOutput):
samples = [o.samples[0] for o in output]
valid_samples = [
sample for sample in samples
if sample.output_token != VLLM_INVALID_TOKEN_ID
]
if len(valid_samples) == 0:
empty_seq_indices.append(i)
continue
if not is_async:
#print("hello")
if self.scheduler_config.is_multi_step:
# Updates happen only if the sequence is prefill
self._update_num_computed_tokens_for_multi_step_prefill(
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size or 0)
if outputs:
for o in outputs:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.runner_type == "pooling":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
if seq_group_meta.do_sample:
self.output_processor.process_outputs(
seq_group, output, is_async)
if seq_group.is_finished():
finished_now.append(i)
# Generate outputs for the requests that finished this iteration
for i in finished_now:
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx_request_outputs.append(request_output)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if request_id:
assert len(indices) == 1
skip.append(indices[0])
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
# Free currently finished requests
if finished_now:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# For multi-step without streaming, don't create outputs each iteration
if not is_last_step and not ctx_multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
# Create the outputs
for i in indices:
if i in skip or i in finished_before or i in finished_now or i in empty_seq_indices:
continue # Avoids double processing
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if not seq_group.is_prefill():
seq_group.set_last_token_time(now)
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs)
if request_output:
ctx_request_outputs.append(request_output)
# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx_multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
return
for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue
request_output = RequestOutputFactory.create(
seq_group,
self.seq_id_to_seq_group,
use_cache=self.use_cached_outputs,
)
if request_output:
ctx_request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given)
if (ctx_request_outputs
and self.process_request_outputs_callback is not None):
self.process_request_outputs_callback(ctx_request_outputs)
ctx_request_outputs.clear()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if is_async:
# Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before,
skip)
# Tracing
self.do_tracing(scheduler_outputs, finished_before)
return None
def _fix_last_step( def _fix_last_step(
self, ctx_output_queue,ctx_request_outputs, self, output: List[SamplerOutput],
ctx_multi_step_stream_outputs,output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
...@@ -1514,9 +1267,7 @@ class LLMEngine: ...@@ -1514,9 +1267,7 @@ class LLMEngine:
for token_id, seq_id in zip(sample_out_list, sample_out_ids): for token_id, seq_id in zip(sample_out_list, sample_out_ids):
if seq.seq_id == seq_id: if seq.seq_id == seq_id:
sample.output_token = token_id[0] sample.output_token = token_id[0]
seq.data._effective_length+=1
seq.fix_last_token_id(sample.output_token) seq.fix_last_token_id(sample.output_token)
self.fix_process_model_output(ctx_output_queue,ctx_request_outputs,ctx_multi_step_stream_outputs)
break break
def _advance_to_next_step( def _advance_to_next_step(
...@@ -1612,9 +1363,9 @@ class LLMEngine: ...@@ -1612,9 +1363,9 @@ class LLMEngine:
last_sampled_token_ids=last_sampled_token_ids, last_sampled_token_ids=last_sampled_token_ids,
last_outputs_ids = last_outputs_ids, last_outputs_ids = last_outputs_ids,
last_outputs_sample = last_outputs_tensor) last_outputs_sample = last_outputs_tensor)
if allow_async_output_proc: # if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[ # execute_model_req.async_callback = self.async_callbacks[
virtual_engine] # virtual_engine]
#profile.ProfRangeAutoPush('model_executor') #profile.ProfRangeAutoPush('model_executor')
outputs = self.model_executor.execute_model( outputs = self.model_executor.execute_model(
...@@ -1637,44 +1388,32 @@ class LLMEngine: ...@@ -1637,44 +1388,32 @@ class LLMEngine:
ctx.scheduler_outputs = scheduler_outputs ctx.scheduler_outputs = scheduler_outputs
self.async_event.synchronize() self.async_event.synchronize()
self._fix_last_step( self._fix_last_step(
ctx.output_queue,
ctx.request_outputs,
ctx.multi_step_stream_outputs,
outputs, seq_group_metadata_list, outputs, seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups) scheduler_outputs.scheduled_seq_groups)
allow_async_output_proc = True # is_first_step_output is True only when the num_steps of all
if not self._has_remaining_steps(seq_group_metadata_list): # the sequences are 1. When the num_steps > 1,
# clear the cache if we have finished all the steps. # multi_step_model_runner does the first-step output append.
if self.scheduler_config.is_multi_step: is_first_step_output: bool = False if not seq_group_metadata_list \
self.cached_scheduler_outputs[0] = SchedulerOutputState() else seq_group_metadata_list[0].state.num_steps == 1
# is_first_step_output is True only when the num_steps of all # Add results to the output_queue
# the sequences are 1. When the num_steps > 1, ctx.append_output(outputs=outputs,
# multi_step_model_runner does the first-step output append. seq_group_metadata_list=seq_group_metadata_list,
is_first_step_output: bool = False if not seq_group_metadata_list \ scheduler_outputs=scheduler_outputs,
else seq_group_metadata_list[0].state.num_steps == 1 is_async=True,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Add results to the output_queue # Check if need to run the usual non-async path
ctx.append_output(outputs=outputs, #if not allow_async_output_proc:
seq_group_metadata_list=seq_group_metadata_list, self._process_model_outputs(ctx=ctx)
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)
# Check if need to run the usual non-async path # Log stats.
if not allow_async_output_proc: self.do_log_stats(scheduler_outputs, outputs)
self._process_model_outputs(ctx=ctx)
# Log stats. # Tracing
self.do_log_stats(scheduler_outputs, outputs) self.do_tracing(scheduler_outputs)
# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs
#profile.ProfRangeAutoPush('has_unfinish') #profile.ProfRangeAutoPush('has_unfinish')
if not self.has_unfinished_requests(): if not self.has_unfinished_requests():
...@@ -1820,8 +1559,7 @@ class LLMEngine: ...@@ -1820,8 +1559,7 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input. # to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids) last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc: if allow_async_output_proc:
if not self.zero_overhead: execute_model_req.async_callback = self.async_callbacks[
execute_model_req.async_callback = self.async_callbacks[
virtual_engine] virtual_engine]
#profile.ProfRangeAutoPush('model_executor') #profile.ProfRangeAutoPush('model_executor')
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
import os
class StopChecker: class StopChecker:
"""LLMEngine helper class which separates out the logic involving stop """LLMEngine helper class which separates out the logic involving stop
...@@ -21,7 +22,6 @@ class StopChecker: ...@@ -21,7 +22,6 @@ class StopChecker:
self._max_model_len = max_model_len self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
def _get_max_model_len(self, lora_req: Optional[LoRARequest]): def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len: if lora_req and lora_req.long_lora_max_len:
...@@ -44,104 +44,53 @@ class StopChecker: ...@@ -44,104 +44,53 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet; # Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not # skip the stop string/token checks if not
if self.zero_overhead: if seq.get_output_len(self.zero_overhead) < sampling_params.min_tokens:
if seq.zero_overhead_get_output_len() < sampling_params.min_tokens: return
return
#new char count的 暂时未修改逻辑
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) if ((not sampling_params.ignore_eos)
and seq.zero_overhead_get_last_token_id() == seq.eos_token_id): and seq.get_last_token_id(self.zero_overhead) == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified # Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token # This prevents unintended exposure of the EOS token
if new_char_count and ( if new_char_count and (
not sampling_params.include_stop_str_in_output): not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count] seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
# Check if a stop token was encountered. # Check if a stop token was encountered.
# This assumes a single token produced per step. # This assumes a single token produced per step.
last_token_id = seq.zero_overhead_get_last_token_id() last_token_id = seq.get_last_token_id(self.zero_overhead)
if last_token_id in (sampling_params.stop_token_ids or ()): if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and ( if new_char_count and (
not sampling_params.include_stop_str_in_output): not sampling_params.include_stop_str_in_output):
# Remove last token # Remove last token
seq.output_text = seq.output_text[:-new_char_count] seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id seq.stop_reason = last_token_id
return return
# Check if any stop strings are matched. # Check if any stop strings are matched.
stop = self.check_stop_strings( stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop, seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output) sampling_params.include_stop_str_in_output)
if stop is not None: if stop is not None:
stop_str, truncate_to = stop stop_str, truncate_to = stop
if truncate_to != -1: if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to] seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str seq.stop_reason = stop_str
return return
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req): if seq.get_len(self.zero_overhead) > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return return
# Check if the sequence has reached max_tokens. # Check if the sequence has reached max_tokens.
if seq.zero_overhead_get_output_len() >= sampling_params.max_tokens: if seq.get_output_len(self.zero_overhead) == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else:
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if new_char_count and (
not sampling_params.include_stop_str_in_output):
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in (sampling_params.stop_token_ids or ()):
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop = self.check_stop_strings(
seq.output_text, new_char_count, sampling_params.stop,
sampling_params.include_stop_str_in_output)
if stop is not None:
stop_str, truncate_to = stop
if truncate_to != -1:
seq.output_text = seq.output_text[:truncate_to]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self._get_max_model_len(lora_req):
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
@staticmethod @staticmethod
def check_stop_strings( def check_stop_strings(
......
...@@ -7,6 +7,7 @@ from array import array ...@@ -7,6 +7,7 @@ from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import reduce from functools import reduce
import os
from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional from typing import Any, Callable, DefaultDict, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union from typing import Set, Tuple, Union
...@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct, ...@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta: Optional[int] = None _mrope_position_delta: Optional[int] = None
_first_step_flag: bool = True _first_step_flag: bool = True
_effective_length:int =0
_effective_length: int = 0
@staticmethod @staticmethod
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData": *token_counts: Tuple[int, int]) -> "SequenceData":
...@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct, ...@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct,
self._new_appended_tokens.append(token_id) self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob self._cumulative_logprob += logprob
def fix_effective_token_id(self, token_id: int,):
effect_offset = self._effective_length - len(self.output_token_ids)
if effect_offset < 0:
self._output_token_ids[effect_offset] = token_id
self._new_appended_tokens[effect_offset] = token_id
self._cached_all_token_ids[effect_offset] = token_id
self._effective_length += 1
def get_len(self) -> int: def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids) return len(self._output_token_ids) + len(self._prompt_token_ids)
def zero_overhead_get_len(self) -> int: def zero_overhead_get_len(self) -> int:
return self._effective_length + len(self._prompt_token_ids) return self._effective_length + len(self._prompt_token_ids)
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return len(self._prompt_token_ids) return len(self._prompt_token_ids)
def get_output_len(self) -> int: def get_output_len(self) -> int:
return len(self._output_token_ids) return len(self._output_token_ids)
def zero_overhead_get_output_len(self) -> Tuple[int, ...]:
return self._effective_length
def zero_overhead_get_output_len(self) -> int:
return self._effective_length
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
return self._cached_all_token_ids return self._cached_all_token_ids
...@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct, ...@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to # of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output. # prefill for both prompt and output.
return self.get_len() - self.get_num_computed_tokens() return self.get_len() - self.get_num_computed_tokens()
def get_last_token_id(self) -> int: def get_last_token_id(self) -> int:
if not self._output_token_ids: if not self._output_token_ids:
return self._prompt_token_ids[-1] return self._prompt_token_ids[-1]
return self._output_token_ids[-1] return self._output_token_ids[-1]
def zero_overhead_get_last_token_id(self) -> int: def zero_overhead_get_last_token_id(self) -> int:
if self._effective_length==0: if self._effective_length == 0:
return self._prompt_token_ids[-1] return self._prompt_token_ids[-1]
return self._output_token_ids[self._effective_length-1] return self._output_token_ids[self._effective_length - 1]
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.prompt_token_ids return self.prompt_token_ids
def zero_overhead_get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids[:self._effective_length]
def get_output_token_ids(self) -> Tuple[int, ...]: def get_output_token_ids(self) -> Tuple[int, ...]:
return self.output_token_ids return self.output_token_ids
...@@ -469,6 +485,7 @@ class Sequence: ...@@ -469,6 +485,7 @@ class Sequence:
self.read_offset = 0 self.read_offset = 0
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
@property @property
def n_blocks(self) -> int: def n_blocks(self) -> int:
...@@ -535,9 +552,9 @@ class Sequence: ...@@ -535,9 +552,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to """If delta is True, only new tokens since the last call to
this method are returned""" this method are returned"""
if not delta: if not delta:
return self.get_output_token_ids() return self.get_output_token_ids(self.zero_overhead)
output_len = self.get_output_len() output_len = self.get_output_len(self.zero_overhead)
# Get the number of new tokens # Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset num_new_tokens = output_len - self._last_output_token_ids_offset
...@@ -547,11 +564,16 @@ class Sequence: ...@@ -547,11 +564,16 @@ class Sequence:
if num_new_tokens == 1: if num_new_tokens == 1:
# Optimization for single decode token case # Optimization for single decode token case
# (which is what we have most of the time) # (which is what we have most of the time)
return self.data._cached_all_token_ids[-1] if self.zero_overhead:
return self.data._cached_all_token_ids[self.data._effective_length - 1]
else:
return self.data._cached_all_token_ids[-1]
if num_new_tokens == 0: if num_new_tokens == 0:
return [] return []
if self.zero_overhead:
return self.data._cached_all_token_ids[-num_new_tokens : self.data._effective_length]
return self.data._cached_all_token_ids[-num_new_tokens:] return self.data._cached_all_token_ids[-num_new_tokens:]
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
...@@ -591,34 +613,35 @@ class Sequence: ...@@ -591,34 +613,35 @@ class Sequence:
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
def fix_last_token_id(self, token_id: int) -> None: def fix_last_token_id(self, token_id: int) -> None:
self.data._output_token_ids[-1] = token_id self.data.fix_effective_token_id(token_id)
self.data._new_appended_tokens[-1] = token_id
self.data._cached_all_token_ids[-1] = token_id
def get_len(self) -> int: def get_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_len()
return self.data.get_len() return self.data.get_len()
def zero_overhead_get_len(self) -> int:
return self.data.zero_overhead_get_len()
def get_prompt_len(self) -> int: def get_prompt_len(self) -> int:
return self.data.get_prompt_len() return self.data.get_prompt_len()
def get_output_len(self) -> int: def get_output_len(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_output_len()
return self.data.get_output_len() return self.data.get_output_len()
def zero_overhead_get_output_len(self) -> int:
return self.data.zero_overhead_get_output_len()
def get_token_ids(self) -> List[int]: def get_token_ids(self) -> List[int]:
return self.data.get_token_ids() return self.data.get_token_ids()
def get_prompt_token_ids(self) -> Tuple[int, ...]: def get_prompt_token_ids(self) -> Tuple[int, ...]:
return self.data.get_prompt_token_ids() return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int: def get_last_token_id(self, zero_overhead = False) -> int:
if zero_overhead:
return self.data.zero_overhead_get_last_token_id()
return self.data.get_last_token_id() return self.data.get_last_token_id()
def zero_overhead_get_last_token_id(self) -> int:
return self.data.zero_overhead_get_last_token_id() def get_output_token_ids(self, zero_overhead = False) -> Tuple[int, ...]:
def get_output_token_ids(self) -> Tuple[int, ...]: if zero_overhead:
return self.data.zero_overhead_get_output_token_ids()
return self.data.get_output_token_ids() return self.data.get_output_token_ids()
def get_cumulative_logprob(self) -> float: def get_cumulative_logprob(self) -> float:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from typing import Dict, List, Optional from typing import Dict, List, Optional
import os
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
Sequence, SequenceGroup) Sequence, SequenceGroup)
...@@ -108,11 +109,11 @@ class Detokenizer: ...@@ -108,11 +109,11 @@ class Detokenizer:
Returns: Returns:
The number of characters added to the output text. The number of characters added to the output text.
""" """
all_input_ids = seq.get_token_ids() all_input_ids = seq.get_token_ids()
if self.zero_overhead: if self.zero_overhead:
eff_length=seq.get_prompt_len()+seq.data._effective_length eff_length = seq.get_prompt_len() + seq.data._effective_length
all_input_ids = seq.get_token_ids()[:eff_length] all_input_ids = seq.get_token_ids()[ : eff_length]
print(f'{all_input_ids=}')
token_id_generated_this_iteration = all_input_ids[-1] token_id_generated_this_iteration = all_input_ids[-1]
tokenizer = self.get_tokenizer_for_seq(seq) tokenizer = self.get_tokenizer_for_seq(seq)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment