Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
052059d9
Commit
052059d9
authored
Mar 24, 2025
by
guanyu1
Browse files
detok修改
parent
2344d22e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
127 additions
and
56 deletions
+127
-56
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+9
-5
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+97
-46
vllm/sequence.py
vllm/sequence.py
+18
-5
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+3
-0
No files found.
benchmarks/benchmark_throughput.py
View file @
052059d9
...
@@ -7,7 +7,7 @@ import random
...
@@ -7,7 +7,7 @@ import random
import
time
import
time
from
functools
import
cache
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
os
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
uvloop
import
uvloop
...
@@ -179,7 +179,7 @@ def run_vllm(
...
@@ -179,7 +179,7 @@ def run_vllm(
sampling_params
:
List
[
SamplingParams
]
=
[]
sampling_params
:
List
[
SamplingParams
]
=
[]
for
request
in
requests
:
for
request
in
requests
:
prompts
.
append
(
prompts
.
append
(
TextPrompt
(
prompt
=
request
.
prompt
,
TextPrompt
(
prompt
=
"helloworld"
,
multi_modal_data
=
request
.
multi_modal_data
))
multi_modal_data
=
request
.
multi_modal_data
))
sampling_params
.
append
(
sampling_params
.
append
(
SamplingParams
(
SamplingParams
(
...
@@ -205,21 +205,25 @@ def run_vllm(
...
@@ -205,21 +205,25 @@ def run_vllm(
dummy_prompts
:
List
[
PromptType
]
=
[{
dummy_prompts
:
List
[
PromptType
]
=
[{
"prompt_token_ids"
:
batch
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
print
(
f
'
{
os
.
environ
.
get
(
"VLLM_ZERO_OVERHEAD"
)
==
"1"
}
'
)
print
(
"Warming up..."
)
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
num_iters_warmup
),
desc
=
"Warmup iterations"
):
for
_
in
tqdm
(
range
(
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
dummy_prompts
,
llm
.
generate
(
dummy_prompts
,
sampling_params
=
warmup_sampling_params
,
sampling_params
=
warmup_sampling_params
,
use_tqdm
=
False
)
use_tqdm
=
False
)
use_beam_search
=
False
use_beam_search
=
False
print
(
"testing"
)
if
not
use_beam_search
:
if
not
use_beam_search
:
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
sampling_params
,
lora_request
=
lora_requests
,
lora_request
=
lora_requests
,
use_tqdm
=
True
)
use_tqdm
=
True
)
for
output
in
outputs
:
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"test生成的文本:
{
generated_text
}
"
)
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
else
:
else
:
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
assert
lora_requests
is
None
,
"BeamSearch API does not support LoRA"
...
...
vllm/engine/output_processor/stop_checker.py
View file @
052059d9
...
@@ -42,53 +42,104 @@ class StopChecker:
...
@@ -42,53 +42,104 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
# skip the stop string/token checks if not
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
if
self
.
zero_overhead
:
return
if
seq
.
zero_overhead_get_output_len
()
<
sampling_params
.
min_tokens
:
return
#new char count的 暂时未修改逻辑
# Check if the sequence has generated the EOS token.
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
zero_overhead_get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
return
# Check if a stop token was encountered.
# Check if a stop token was encountered.
# This assumes a single token produced per step.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
zero_overhead_get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
seq
.
stop_reason
=
last_token_id
return
return
# Check if any stop strings are matched.
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
seq
.
stop_reason
=
stop_str
return
return
# Check if the sequence has reached max_model_len.
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
if
seq
.
zero_overhead_get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
# Check if the sequence has reached max_tokens.
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
if
seq
.
zero_overhead_get_output_len
()
>=
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
return
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
else
:
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
seq
.
output_text
=
seq
.
output_text
[:
-
new_char_count
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
last_token_id
return
# Check if any stop strings are matched.
stop
=
self
.
check_stop_strings
(
seq
.
output_text
,
new_char_count
,
sampling_params
.
stop
,
sampling_params
.
include_stop_str_in_output
)
if
stop
is
not
None
:
stop_str
,
truncate_to
=
stop
if
truncate_to
!=
-
1
:
seq
.
output_text
=
seq
.
output_text
[:
truncate_to
]
seq
.
status
=
SequenceStatus
.
FINISHED_STOPPED
seq
.
stop_reason
=
stop_str
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
@
staticmethod
@
staticmethod
def
check_stop_strings
(
def
check_stop_strings
(
...
...
vllm/sequence.py
View file @
052059d9
...
@@ -177,7 +177,7 @@ class SequenceData(msgspec.Struct,
...
@@ -177,7 +177,7 @@ class SequenceData(msgspec.Struct,
_mrope_position_delta
:
Optional
[
int
]
=
None
_mrope_position_delta
:
Optional
[
int
]
=
None
_first_step_flag
:
bool
=
True
_first_step_flag
:
bool
=
True
_effective_length
:
int
=
0
@
staticmethod
@
staticmethod
def
from_prompt_token_counts
(
def
from_prompt_token_counts
(
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
...
@@ -310,12 +310,16 @@ class SequenceData(msgspec.Struct,
...
@@ -310,12 +310,16 @@ class SequenceData(msgspec.Struct,
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
_effective_length
+
len
(
self
.
_prompt_token_ids
)
def
get_prompt_len
(
self
)
->
int
:
def
get_prompt_len
(
self
)
->
int
:
return
len
(
self
.
_prompt_token_ids
)
return
len
(
self
.
_prompt_token_ids
)
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
return
len
(
self
.
_output_token_ids
)
def
zero_overhead_get_output_len
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
_effective_length
def
get_token_ids
(
self
)
->
List
[
int
]:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
_cached_all_token_ids
return
self
.
_cached_all_token_ids
...
@@ -372,7 +376,11 @@ class SequenceData(msgspec.Struct,
...
@@ -372,7 +376,11 @@ class SequenceData(msgspec.Struct,
if
not
self
.
_output_token_ids
:
if
not
self
.
_output_token_ids
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
-
1
]
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
if
self
.
_effective_length
==
0
:
return
self
.
_prompt_token_ids
[
-
1
]
return
self
.
_output_token_ids
[
self
.
_effective_length
-
1
]
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
prompt_token_ids
return
self
.
prompt_token_ids
...
@@ -589,13 +597,17 @@ class Sequence:
...
@@ -589,13 +597,17 @@ class Sequence:
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
self
.
data
.
get_len
()
return
self
.
data
.
get_len
()
def
zero_overhead_get_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_len
()
def
get_prompt_len
(
self
)
->
int
:
def
get_prompt_len
(
self
)
->
int
:
return
self
.
data
.
get_prompt_len
()
return
self
.
data
.
get_prompt_len
()
def
get_output_len
(
self
)
->
int
:
def
get_output_len
(
self
)
->
int
:
return
self
.
data
.
get_output_len
()
return
self
.
data
.
get_output_len
()
def
zero_overhead_get_output_len
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_output_len
()
def
get_token_ids
(
self
)
->
List
[
int
]:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
data
.
get_token_ids
()
return
self
.
data
.
get_token_ids
()
...
@@ -604,7 +616,8 @@ class Sequence:
...
@@ -604,7 +616,8 @@ class Sequence:
def
get_last_token_id
(
self
)
->
int
:
def
get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
get_last_token_id
()
return
self
.
data
.
get_last_token_id
()
def
zero_overhead_get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
zero_overhead_get_last_token_id
()
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
data
.
get_output_token_ids
()
return
self
.
data
.
get_output_token_ids
()
...
...
vllm/transformers_utils/detokenizer.py
View file @
052059d9
...
@@ -108,6 +108,9 @@ class Detokenizer:
...
@@ -108,6 +108,9 @@ class Detokenizer:
The number of characters added to the output text.
The number of characters added to the output text.
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
if
self
.
zero_overhead
:
all_input_ids
=
seq
.
get_token_ids
()[:
seq
.
get_prompt_len
()
+
self
.
data
.
_effective_length
]
print
(
f
'
{
all_input_ids
=
}
'
)
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment