Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0be169ad
Commit
0be169ad
authored
Mar 25, 2025
by
lizhigong
Browse files
debug and fix some error about outputs
parent
18b9f67c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
414 deletions
+125
-414
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+24
-286
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+48
-99
vllm/sequence.py
vllm/sequence.py
+47
-24
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+6
-5
No files found.
vllm/engine/llm_engine.py
View file @
0be169ad
...
@@ -1244,255 +1244,8 @@ class LLMEngine:
...
@@ -1244,255 +1244,8 @@ class LLMEngine:
return
None
return
None
def
fix_process_model_output
(
self
,
ctx_output_queue
,
ctx_request_outputs
,
ctx_multi_step_stream_outputs
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
now
=
time
.
time
()
if
len
(
ctx_output_queue
)
==
0
:
return
None
# Get pending async postprocessor
if
request_id
:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
is_first_step_output
,
skip
)
=
ctx_output_queue
[
0
]
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
is_first_step_output
,
skip
)
=
ctx_output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
outputs_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
if
has_multiple_outputs
:
assert
self
.
scheduler_config
.
is_multi_step
or
\
self
.
speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if
self
.
scheduler_config
.
is_multi_step
:
outputs_by_sequence_group
=
create_output_by_sequence_group
(
outputs
,
len
(
seq_group_metadata_list
))
elif
self
.
speculative_config
:
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger chunk
# processing without having to pad or copy over prompts K times
# to match decodes structure (costly with prompt_logprobs).
num_prefills
=
sum
(
sg
.
is_prompt
for
sg
in
seq_group_metadata_list
)
prefills
,
decodes
=
outputs
[:
num_prefills
],
outputs
[
num_prefills
:]
outputs_by_sequence_group
=
create_output_by_sequence_group
(
decodes
,
num_seq_groups
=
len
(
seq_group_metadata_list
)
-
num_prefills
)
outputs_by_sequence_group
=
[
p
.
outputs
for
p
in
prefills
]
+
outputs_by_sequence_group
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output
=
None
elif
len
(
outputs
)
==
1
:
outputs_by_sequence_group
=
outputs
else
:
return
None
# Determine the requests we need to operate on
if
request_id
:
indices
=
[]
for
i
,
seq_group_meta
in
enumerate
(
seq_group_metadata_list
):
if
seq_group_meta
.
request_id
==
request_id
:
assert
i
not
in
skip
# Cannot be called twice
indices
.
append
(
i
)
break
# If the request_id was not found, then it means that
# this is a new request that has no pending async
# postprocessor
if
not
indices
:
return
else
:
indices
=
range
(
len
(
seq_group_metadata_list
))
# type: ignore
finished_before
:
List
[
int
]
=
[]
finished_now
:
List
[
int
]
=
[]
empty_seq_indices
:
List
[
int
]
=
[]
for
i
in
indices
:
if
i
in
skip
:
continue
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
:
SequenceGroup
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
finished_before
.
append
(
i
)
continue
output
:
List
[
SequenceGroupOutput
]
if
has_multiple_outputs
:
output
=
outputs_by_sequence_group
[
i
]
else
:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
# tree style speculative decoding may generate empty output in first step
if
self
.
tree_decoding
and
outputs
and
isinstance
(
output
[
0
],
CompletionSequenceGroupOutput
):
samples
=
[
o
.
samples
[
0
]
for
o
in
output
]
valid_samples
=
[
sample
for
sample
in
samples
if
sample
.
output_token
!=
VLLM_INVALID_TOKEN_ID
]
if
len
(
valid_samples
)
==
0
:
empty_seq_indices
.
append
(
i
)
continue
if
not
is_async
:
#print("hello")
if
self
.
scheduler_config
.
is_multi_step
:
# Updates happen only if the sequence is prefill
self
.
_update_num_computed_tokens_for_multi_step_prefill
(
seq_group
,
seq_group_meta
,
is_first_step_output
)
else
:
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
or
0
)
if
outputs
:
for
o
in
outputs
:
if
(
isinstance
(
o
,
SamplerOutput
)
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
or
0
)
else
:
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
or
0
)
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
if
self
.
model_config
.
runner_type
==
"pooling"
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
if
seq_group
.
is_finished
():
finished_now
.
append
(
i
)
# Generate outputs for the requests that finished this iteration
for
i
in
finished_now
:
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
not
seq_group
.
is_prefill
():
seq_group
.
set_last_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
if
request_id
:
assert
len
(
indices
)
==
1
skip
.
append
(
indices
[
0
])
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
# Free currently finished requests
if
finished_now
:
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_finished_seq_groups
()
# For multi-step without streaming, don't create outputs each iteration
if
not
is_last_step
and
not
ctx_multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
(
finished_now
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
# Create the outputs
for
i
in
indices
:
if
i
in
skip
or
i
in
finished_before
or
i
in
finished_now
or
i
in
empty_seq_indices
:
continue
# Avoids double processing
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
if
not
seq_group
.
is_prefill
():
seq_group
.
set_last_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# For multi-step with streaming, create outputs each iteration
if
not
is_last_step
and
ctx_multi_step_stream_outputs
:
# Immediately process request outputs here (if callback is given)
if
self
.
process_request_outputs_callback
is
not
None
:
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
return
for
seq_group
in
scheduler_outputs
.
ignored_seq_groups
:
params
=
seq_group
.
sampling_params
if
params
is
not
None
and
params
.
output_kind
==
(
RequestOutputKind
.
DELTA
)
and
not
seq_group
.
is_finished
():
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
,
)
if
request_output
:
ctx_request_outputs
.
append
(
request_output
)
# Immediately process request outputs here (if callback is given)
if
(
ctx_request_outputs
and
self
.
process_request_outputs_callback
is
not
None
):
self
.
process_request_outputs_callback
(
ctx_request_outputs
)
ctx_request_outputs
.
clear
()
# For async case, we need to record the stats here.
# For non-async case, the stats are done in the
# LLMEngine/AsyncLLMEngine directly
if
is_async
:
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
outputs
,
finished_before
,
skip
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
,
finished_before
)
return
None
def
_fix_last_step
(
def
_fix_last_step
(
self
,
ctx_output_queue
,
ctx_request_outputs
,
self
,
output
:
List
[
SamplerOutput
],
ctx_multi_step_stream_outputs
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
])
->
None
:
...
@@ -1514,9 +1267,7 @@ class LLMEngine:
...
@@ -1514,9 +1267,7 @@ class LLMEngine:
for
token_id
,
seq_id
in
zip
(
sample_out_list
,
sample_out_ids
):
for
token_id
,
seq_id
in
zip
(
sample_out_list
,
sample_out_ids
):
if
seq
.
seq_id
==
seq_id
:
if
seq
.
seq_id
==
seq_id
:
sample
.
output_token
=
token_id
[
0
]
sample
.
output_token
=
token_id
[
0
]
seq
.
data
.
_effective_length
+=
1
seq
.
fix_last_token_id
(
sample
.
output_token
)
seq
.
fix_last_token_id
(
sample
.
output_token
)
self
.
fix_process_model_output
(
ctx_output_queue
,
ctx_request_outputs
,
ctx_multi_step_stream_outputs
)
break
break
def
_advance_to_next_step
(
def
_advance_to_next_step
(
...
@@ -1612,9 +1363,9 @@ class LLMEngine:
...
@@ -1612,9 +1363,9 @@ class LLMEngine:
last_sampled_token_ids
=
last_sampled_token_ids
,
last_sampled_token_ids
=
last_sampled_token_ids
,
last_outputs_ids
=
last_outputs_ids
,
last_outputs_ids
=
last_outputs_ids
,
last_outputs_sample
=
last_outputs_tensor
)
last_outputs_sample
=
last_outputs_tensor
)
if
allow_async_output_proc
:
#
if allow_async_output_proc:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
#
execute_model_req.async_callback = self.async_callbacks[
virtual_engine
]
#
virtual_engine]
#profile.ProfRangeAutoPush('model_executor')
#profile.ProfRangeAutoPush('model_executor')
outputs
=
self
.
model_executor
.
execute_model
(
outputs
=
self
.
model_executor
.
execute_model
(
...
@@ -1637,44 +1388,32 @@ class LLMEngine:
...
@@ -1637,44 +1388,32 @@ class LLMEngine:
ctx
.
scheduler_outputs
=
scheduler_outputs
ctx
.
scheduler_outputs
=
scheduler_outputs
self
.
async_event
.
synchronize
()
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
self
.
_fix_last_step
(
ctx
.
output_queue
,
ctx
.
request_outputs
,
ctx
.
multi_step_stream_outputs
,
outputs
,
seq_group_metadata_list
,
outputs
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
)
allow_async_output_proc
=
True
# is_first_step_output is True only when the num_steps of all
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# the sequences are 1. When the num_steps > 1,
# clear the cache if we have finished all the steps
.
# multi_step_model_runner does the first-step output append
.
if
self
.
scheduler_config
.
is_multi_step
:
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
# is_first_step_output is True only when the num_steps of all
# Add results to the output_queue
# the sequences are 1. When the num_steps > 1,
ctx
.
append_output
(
outputs
=
outputs
,
# multi_step_model_runner does the first-step output append.
seq_group_metadata_list
=
seq_group_metadata_list
,
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
scheduler_outputs
=
scheduler_outputs
,
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
is_async
=
True
,
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
# Add results to the output_queue
# Check if need to run the usual non-async path
ctx
.
append_output
(
outputs
=
outputs
,
#if not allow_async_output_proc:
seq_group_metadata_list
=
seq_group_metadata_list
,
self
.
_process_model_outputs
(
ctx
=
ctx
)
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_proc
,
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
# Check if need to run the usual non-async path
# Log stats.
if
not
allow_async_output_proc
:
self
.
do_log_stats
(
scheduler_outputs
,
outputs
)
self
.
_process_model_outputs
(
ctx
=
ctx
)
# Log stats.
# Tracing
self
.
do_log_stats
(
scheduler_outputs
,
outputs
)
self
.
do_tracing
(
scheduler_outputs
)
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
else
:
# Multi-step case
return
ctx
.
request_outputs
#profile.ProfRangeAutoPush('has_unfinish')
#profile.ProfRangeAutoPush('has_unfinish')
if
not
self
.
has_unfinished_requests
():
if
not
self
.
has_unfinished_requests
():
...
@@ -1820,8 +1559,7 @@ class LLMEngine:
...
@@ -1820,8 +1559,7 @@ class LLMEngine:
# to each of the non-last PP stages for in-place prepare_input.
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
if
allow_async_output_proc
:
if
not
self
.
zero_overhead
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
virtual_engine
]
#profile.ProfRangeAutoPush('model_executor')
#profile.ProfRangeAutoPush('model_executor')
...
...
vllm/engine/output_processor/stop_checker.py
View file @
0be169ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
import
os
class
StopChecker
:
class
StopChecker
:
"""LLMEngine helper class which separates out the logic involving stop
"""LLMEngine helper class which separates out the logic involving stop
...
@@ -21,7 +22,6 @@ class StopChecker:
...
@@ -21,7 +22,6 @@ class StopChecker:
self
.
_max_model_len
=
max_model_len
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
if
lora_req
and
lora_req
.
long_lora_max_len
:
if
lora_req
and
lora_req
.
long_lora_max_len
:
...
@@ -44,104 +44,53 @@ class StopChecker:
...
@@ -44,104 +44,53 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
# skip the stop string/token checks if not
if
self
.
zero_overhead
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
<
sampling_params
.
min_tokens
:
if
seq
.
zero_overhead_get_output_len
()
<
sampling_params
.
min_tokens
:
return
return
#new char count的 暂时未修改逻辑
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
zero_overhead_get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
get_last_token_id
(
self
.
zero_overhead
)
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
return
# Check if a stop token was encountered.
# Check if a stop token was encountered.
# This assumes a single token produced per step.
# This assumes a single token produced per step.
last_token_id
=
seq
.
zero_overhead_get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
(
self
.
zero_overhead
)
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
seq
.
stop_reason
=
last_token_id
return
return
# Check if any stop strings are matched.
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
seq
.
stop_reason
=
stop_str
return
return
# Check if the sequence has reached max_model_len.
# Check if the sequence has reached max_model_len.
if
seq
.
zero_overhead_get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
if
seq
.
get_len
(
self
.
zero_overhead
)
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
# Check if the sequence has reached max_tokens.
# Check if the sequence has reached max_tokens.
if
seq
.
zero_overhead_get_output_len
()
>=
sampling_params
.
max_tokens
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else
:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
@
staticmethod
@
staticmethod
def
check_stop_strings
(
def
check_stop_strings
(
...
...
vllm/sequence.py
View file @
0be169ad
...
@@ -7,6 +7,7 @@ from array import array
...
@@ -7,6 +7,7 @@ from array import array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
reduce
from
functools
import
reduce
import
os
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
from
typing
import
Set
,
Tuple
,
Union
...
@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct,
...
@@ -177,7 +178,9 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta
:
Optional
[
int
]
=
None
_mrope_position_delta
:
Optional
[
int
]
=
None
_first_step_flag
:
bool
=
True
_first_step_flag
:
bool
=
True
_effective_length
:
int
=
0
_effective_length
:
int
=
0
@
staticmethod
@
staticmethod
def
from_prompt_token_counts
(
def
from_prompt_token_counts
(
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
...
@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct,
...
@@ -307,20 +310,30 @@ class SequenceData(msgspec.Struct,
self
.
_new_appended_tokens
.
append
(
token_id
)
self
.
_new_appended_tokens
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cumulative_logprob
+=
logprob
self
.
_cumulative_logprob
+=
logprob
def
fix_effective_token_id
(
self
,
token_id
:
int
,):
effect_offset
=
self
.
_effective_length
-
len
(
self
.
output_token_ids
)
if
effect_offset
<
0
:
self
.
_output_token_ids
[
effect_offset
]
=
token_id
self
.
_new_appended_tokens
[
effect_offset
]
=
token_id
self
.
_cached_all_token_ids
[
effect_offset
]
=
token_id
self
.
_effective_length
+=
1
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
def
zero_overhead_get_len
(
self
)
->
int
:
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
_effective_length
+
len
(
self
.
_prompt_token_ids
)
return
self
.
_effective_length
+
len
(
self
.
_prompt_token_ids
)
def
get_prompt_len
(
self
)
->
int
:
def
get_prompt_len
(
self
)
->
int
:
return
len
(
self
.
_prompt_token_ids
)
return
len
(
self
.
_prompt_token_ids
)
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
return
len
(
self
.
_output_token_ids
)
def
zero_overhead_get_output_len
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
_effective_length
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
_effective_length
def
get_token_ids
(
self
)
->
List
[
int
]:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
_cached_all_token_ids
return
self
.
_cached_all_token_ids
...
@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct,
...
@@ -371,19 +384,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
# prefill for both prompt and output.
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
)
->
int
:
if
not
self
.
_output_token_ids
:
if
not
self
.
_output_token_ids
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
if
self
.
_effective_length
==
0
:
if
self
.
_effective_length
==
0
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
self
.
_effective_length
-
1
]
return
self
.
_output_token_ids
[
self
.
_effective_length
-
1
]
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
prompt_token_ids
return
self
.
prompt_token_ids
def
zero_overhead_get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
output_token_ids
[:
self
.
_effective_length
]
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
output_token_ids
return
self
.
output_token_ids
...
@@ -469,6 +485,7 @@ class Sequence:
...
@@ -469,6 +485,7 @@ class Sequence:
self
.
read_offset
=
0
self
.
read_offset
=
0
# Input + output tokens
# Input + output tokens
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
@
property
@
property
def
n_blocks
(
self
)
->
int
:
def
n_blocks
(
self
)
->
int
:
...
@@ -535,9 +552,9 @@ class Sequence:
...
@@ -535,9 +552,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to
"""If delta is True, only new tokens since the last call to
this method are returned"""
this method are returned"""
if
not
delta
:
if
not
delta
:
return
self
.
get_output_token_ids
()
return
self
.
get_output_token_ids
(
self
.
zero_overhead
)
output_len
=
self
.
get_output_len
()
output_len
=
self
.
get_output_len
(
self
.
zero_overhead
)
# Get the number of new tokens
# Get the number of new tokens
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
...
@@ -547,11 +564,16 @@ class Sequence:
...
@@ -547,11 +564,16 @@ class Sequence:
if
num_new_tokens
==
1
:
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# Optimization for single decode token case
# (which is what we have most of the time)
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
self
.
data
.
_effective_length
-
1
]
else
:
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
num_new_tokens
==
0
:
if
num_new_tokens
==
0
:
return
[]
return
[]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:
self
.
data
.
_effective_length
]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
...
@@ -591,34 +613,35 @@ class Sequence:
...
@@ -591,34 +613,35 @@ class Sequence:
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
def
fix_last_token_id
(
self
,
token_id
:
int
)
->
None
:
def
fix_last_token_id
(
self
,
token_id
:
int
)
->
None
:
self
.
data
.
_output_token_ids
[
-
1
]
=
token_id
self
.
data
.
fix_effective_token_id
(
token_id
)
self
.
data
.
_new_appended_tokens
[
-
1
]
=
token_id
self
.
data
.
_cached_all_token_ids
[
-
1
]
=
token_id
def
get_len
(
self
)
->
int
:
def
get_len
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_len
()
return
self
.
data
.
get_len
()
return
self
.
data
.
get_len
()
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_len
()
def
get_prompt_len
(
self
)
->
int
:
def
get_prompt_len
(
self
)
->
int
:
return
self
.
data
.
get_prompt_len
()
return
self
.
data
.
get_prompt_len
()
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_output_len
()
return
self
.
data
.
get_output_len
()
return
self
.
data
.
get_output_len
()
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_output_len
()
def
get_token_ids
(
self
)
->
List
[
int
]:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
data
.
get_token_ids
()
return
self
.
data
.
get_token_ids
()
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
data
.
get_prompt_token_ids
()
return
self
.
data
.
get_prompt_token_ids
()
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
,
zero_overhead
=
False
)
->
int
:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_last_token_id
()
return
self
.
data
.
get_last_token_id
()
return
self
.
data
.
get_last_token_id
()
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_last_token_id
()
def
get_output_token_ids
(
self
,
zero_overhead
=
False
)
->
Tuple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
if
zero_overhead
:
return
self
.
data
.
zero_overhead_get_output_token_ids
()
return
self
.
data
.
get_output_token_ids
()
return
self
.
data
.
get_output_token_ids
()
def
get_cumulative_logprob
(
self
)
->
float
:
def
get_cumulative_logprob
(
self
)
->
float
:
...
...
vllm/transformers_utils/detokenizer.py
View file @
0be169ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
os
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Logprob
,
SamplingParams
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Logprob
,
SamplingParams
,
Sequence
,
SequenceGroup
)
Sequence
,
SequenceGroup
)
...
@@ -108,11 +109,11 @@ class Detokenizer:
...
@@ -108,11 +109,11 @@ class Detokenizer:
Returns:
Returns:
The number of characters added to the output text.
The number of characters added to the output text.
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
if
self
.
zero_overhead
:
if
self
.
zero_overhead
:
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
data
.
_effective_length
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
data
.
_effective_length
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
print
(
f
'
{
all_input_ids
=
}
'
)
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment