Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0be169ad
Commit
0be169ad
authored
Mar 25, 2025
by
lizhigong
Browse files
debug and fix some error about outputs
parent
18b9f67c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
414 deletions
+125
-414
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+24
-286
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+48
-99
vllm/sequence.py
vllm/sequence.py
+47
-24
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+6
-5
No files found.
vllm/engine/llm_engine.py
View file @
0be169ad
...
...
@@ -1244,255 +1244,8 @@ class LLMEngine:
return
None
def
fix_process_model_output
(
self
,
ctx_output_queue
,
ctx_request_outputs
,
ctx_multi_step_stream_outputs
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
now
=
time
.
time
()
if
len
(
ctx_output_queue
)
==
0
:
return
None
# Get pending async postprocessor
if
request_id
:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
is_first_step_output
,
skip
)
=
ctx_output_queue
[
0
]
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
is_first_step_output
,
skip
)
=
ctx_output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
outputs_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
if
has_multiple_outputs
:
assert
self
.
scheduler_config
.
is_multi_step
or
\
self
.
speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if
self
.
scheduler_config
.
is_multi_step
:
outputs_by_sequence_group
=
create_output_by_sequence_group
(
outputs
,
len
(
seq_group_metadata_list
))
elif
self
.
speculative_config
:
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match decodes structure (costly with prompt_logprobs).
num_prefills
=
sum
(
sg
.
is_prompt
for
sg
in
seq_group_metadata_list
)
prefills
,
decodes
=
outputs
[:
num_prefills
],
outputs
[
num_prefills
:]
outputs_by_sequence_group
=
create_output_by_sequence_group
(
decodes
,
num_seq_groups
=
len
(
seq_group_metadata_list
)
-
num_prefills
)
outputs_by_sequence_group
=
[
p
.
outputs
for
p
in
prefills
]
+
outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output
=
None
elif
len
(
outputs
)
==
1
:
outputs_by_sequence_group
=
outputs
else
:
return
None
# Determine the requests we need to operate on
if
request_id
:
indices
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_meta
.
request_id
==
request_id
:
assert
i
not
in
skip
# Cannot be called twice
indices
.
append
(
i
)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if
not
indices
:
return
else
:
indices
=
range
(
len
(
seq_group_metadata_list
))
# type: ignore
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
empty_seq_indices
:
List
[
int
]
=
[]
for
i
in
indices
:
if
i
in
skip
:
continue
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
:
SequenceGroup
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
finished_before
.
append
(
i
)
continue
output
:
List
[
SequenceGroupOutput
]
if
has_multiple_outputs
:
output
=
outputs_by_sequence_group
[
i
]
else
:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
if
self
.
tree_decoding
and
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
if
len
(
valid_samples
)
==
0
:
empty_seq_indices
.
append
(
i
)
continue
if
not
is_async
:
#print("hello")
if
self
.
scheduler_config
.
is_multi_step
:
# Updates happen only if the sequence is prefill
self
.
_update_num_computed_tokens_for_multi_step_prefill
(
seq_group
,
seq_group_meta
,
is_first_step_output
)
else
:
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
or
0
)
if
outputs
:
for
o
in
outputs
:
if
(
isinstance
(
o
,
SamplerOutput
)
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
or
0
)
else
:
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
or
0
)
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
if
self
.
model_config
.
runner_type
==
"pooling"
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
if
seq_group
.
is_finished
():
finished_now
.
append
(
i
)
# Generate outputs for the requests that finished this iteration
for
i
in
finished_now
:
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
not
seq_group
.
is_prefill
():
seq_group
.
set_last_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if
request_id
:
assert
len
(
indices
)
==
1
skip
.
append
(
indices
[
0
])
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step without streaming, don't create outputs each iteration
if
not
is_last_step
and
not
ctx_multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
# Create the outputs
for
i
in
indices
:
if
i
in
skip
or
i
in
finished_before
or
i
in
finished_now
or
i
in
empty_seq_indices
:
continue
# Avoids double processing
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
not
seq_group
.
is_prefill
():
seq_group
.
set_last_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx_multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
,
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# Immediately process request outputs here (if callback is given)
if
(
ctx_request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if
is_async
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
,
skip
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
,
finished_before
)
return
None
def
_fix_last_step
(
self
,
ctx_output_queue
,
ctx_request_outputs
,
ctx_multi_step_stream_outputs
,
output
:
List
[
SamplerOutput
],
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
...
...
@@ -1514,9 +1267,7 @@ class LLMEngine:
for
token_id
,
seq_id
in
zip
(
sample_out_list
,
sample_out_ids
):
if
seq
.
seq_id
==
seq_id
:
sample
.
output_token
=
token_id
[
0
]
seq
.
data
.
_effective_length
+=
1
seq
.
fix_last_token_id
(
sample
.
output_token
)
self
.
fix_process_model_output
(
ctx_output_queue
,
ctx_request_outputs
,
ctx_multi_step_stream_outputs
)
break
def
_advance_to_next_step
(
...
...
@@ -1612,9 +1363,9 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
#
if allow_async_output_proc:
#
execute_model_req.async_callback = self.async_callbacks[
#
virtual_engine]
#profile.ProfRangeAutoPush('model_executor')
outputs
=
self
.
model_executor
.
execute_model
(
...
...
@@ -1637,44 +1388,32 @@ class LLMEngine:
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
ctx
.
output_queue
,
ctx
.
request_outputs
,
ctx
.
multi_step_stream_outputs
,
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
allow_async_output_proc
=
True
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps
.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append
.
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
# Add results to the output_queue
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
True
,
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
# Add results to the output_queue
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_proc
,
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
# Check if need to run the usual non-async path
#if not allow_async_output_proc:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Check if need to run the usual non-async path
if
not
allow_async_output_proc
:
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
)
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
return
ctx
.
request_outputs
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
#profile.ProfRangeAutoPush('has_unfinish')
if
not
self
.
has_unfinished_requests
():
...
...
@@ -1820,8 +1559,7 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
not
self
.
zero_overhead
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
#profile.ProfRangeAutoPush('model_executor')
...
...
vllm/engine/output_processor/stop_checker.py
View file @
0be169ad
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
import
os
class
StopChecker
:
"""LLMEngine helper class which separates out the logic involving stop
...
...
@@ -21,7 +22,6 @@ class StopChecker:
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
if
lora_req
and
lora_req
.
long_lora_max_len
:
...
...
@@ -44,104 +44,53 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if
self
.
zero_overhead
:
if
seq
.
zero_overhead_get_output_len
()
<
sampling_params
.
min_tokens
:
return
#new char count的 暂时未修改逻辑
if
seq
.
get_output_len
(
self
.
zero_overhead
)
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
zero_overhead_get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
zero_overhead_get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
zero_overhead_get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
zero_overhead_get_output_len
()
>=
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else
:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
(
self
.
zero_overhead
)
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
(
self
.
zero_overhead
)
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
(
self
.
zero_overhead
)
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
(
self
.
zero_overhead
)
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
@
staticmethod
def
check_stop_strings
(
...
...
vllm/sequence.py
View file @
0be169ad
...
...
@@ -7,6 +7,7 @@ from array import array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
functools
import
reduce
import
os
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
...
...
@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta
:
Optional
[
int
]
=
None
_first_step_flag
:
bool
=
True
_effective_length
:
int
=
0
_effective_length
:
int
=
0
@
staticmethod
def
from_prompt_token_counts
(
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
...
...
@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct,
self
.
_new_appended_tokens
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cumulative_logprob
+=
logprob
def
fix_effective_token_id
(
self
,
token_id
:
int
,):
effect_offset
=
self
.
_effective_length
-
len
(
self
.
output_token_ids
)
if
effect_offset
<
0
:
self
.
_output_token_ids
[
effect_offset
]
=
token_id
self
.
_new_appended_tokens
[
effect_offset
]
=
token_id
self
.
_cached_all_token_ids
[
effect_offset
]
=
token_id
self
.
_effective_length
+=
1
def
get_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
_effective_length
+
len
(
self
.
_prompt_token_ids
)
def
get_prompt_len
(
self
)
->
int
:
return
len
(
self
.
_prompt_token_ids
)
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
def
zero_overhead_get_output_len
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
_effective_length
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
_effective_length
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
_cached_all_token_ids
...
...
@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
def
get_last_token_id
(
self
)
->
int
:
if
not
self
.
_output_token_ids
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
if
self
.
_effective_length
==
0
:
if
self
.
_effective_length
==
0
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
self
.
_effective_length
-
1
]
return
self
.
_output_token_ids
[
self
.
_effective_length
-
1
]
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
prompt_token_ids
def
zero_overhead_get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
output_token_ids
[:
self
.
_effective_length
]
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
output_token_ids
...
...
@@ -469,6 +485,7 @@ class Sequence:
self
.
read_offset
=
0
# Input + output tokens
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
@
property
def
n_blocks
(
self
)
->
int
:
...
...
@@ -535,9 +552,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
return
self
.
get_output_token_ids
(
self
.
zero_overhead
)
output_len
=
self
.
get_output_len
()
output_len
=
self
.
get_output_len
(
self
.
zero_overhead
)
# Get the number of new tokens
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
...
...
@@ -547,11 +564,16 @@ class Sequence:
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
self
.
data
.
_effective_length
-
1
]
else
:
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
num_new_tokens
==
0
:
return
[]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:
self
.
data
.
_effective_length
]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
...
...
@@ -591,34 +613,35 @@ class Sequence:
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
def
fix_last_token_id
(
self
,
token_id
:
int
)
->
None
:
self
.
data
.
_output_token_ids
[
-
1
]
=
token_id
self
.
data
.
_new_appended_tokens
[
-
1
]
=
token_id
self
.
data
.
_cached_all_token_ids
[
-
1
]
=
token_id
self
.
data
.
fix_effective_token_id
(
token_id
)
def
get_len
(
self
)
->
int
:
def
get_len
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_len
()
return
self
.
data
.
get_len
()
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_len
()
def
get_prompt_len
(
self
)
->
int
:
return
self
.
data
.
get_prompt_len
()
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_output_len
()
return
self
.
data
.
get_output_len
()
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_output_len
()
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
data
.
get_token_ids
()
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
data
.
get_prompt_token_ids
()
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_last_token_id
()
return
self
.
data
.
get_last_token_id
()
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_last_token_id
()
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_output_token_ids
(
self
,
zero_overhead
=
False
)
->
Tuple
[
int
,
...]:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_output_token_ids
()
return
self
.
data
.
get_output_token_ids
()
def
get_cumulative_logprob
(
self
)
->
float
:
...
...
vllm/transformers_utils/detokenizer.py
View file @
0be169ad
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Dict
,
List
,
Optional
import
os
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
)
...
...
@@ -108,11 +109,11 @@ class Detokenizer:
Returns:
The number of characters added to the output text.
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
if
self
.
zero_overhead
:
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
data
.
_effective_length
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
print
(
f
'
{
all_input_ids
=
}
'
)
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
data
.
_effective_length
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment