Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a37d75bb
Unverified
Commit
a37d75bb
authored
Jul 08, 2025
by
ztang2370
Committed by
GitHub
Jul 07, 2025
Browse files
[Front-end] microbatch tokenization (#19334)
Signed-off-by:
zt2370
<
ztang2370@gmail.com
>
parent
edd270bc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
288 additions
and
64 deletions
+288
-64
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+23
-16
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+73
-48
vllm/utils/__init__.py
vllm/utils/__init__.py
+192
-0
No files found.
tests/entrypoints/openai/test_serving_chat.py
View file @
a37d75bb
...
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
...
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.config
import
MultiModalConfig
from
vllm.config
import
MultiModalConfig
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
...
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
...
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
def
test_serving_chat_should_set_correct_max_tokens
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_should_set_correct_max_tokens
():
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
mock_engine
.
errored
=
False
...
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template
=
CHAT_TEMPLATE
,
chat_template
=
CHAT_TEMPLATE
,
chat_template_content_format
=
"auto"
,
chat_template_content_format
=
"auto"
,
request_logger
=
None
)
request_logger
=
None
)
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
[{
messages
=
[{
...
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
)
)
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
req
.
max_tokens
=
10
req
.
max_tokens
=
10
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
)
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
15
req
.
max_tokens
=
15
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
5
req
.
max_tokens
=
5
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
...
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
)
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
...
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
100
req
.
max_tokens
=
100
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
...
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
...
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
5
req
.
max_tokens
=
5
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
def
test_serving_chat_could_load_correct_generation_config
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_could_load_correct_generation_config
():
mock_model_config
=
MockModelConfig
()
mock_model_config
=
MockModelConfig
()
mock_model_config
.
diff_sampling_param
=
{
mock_model_config
.
diff_sampling_param
=
{
...
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
...
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template
=
CHAT_TEMPLATE
,
chat_template
=
CHAT_TEMPLATE
,
chat_template_content_format
=
"auto"
,
chat_template_content_format
=
"auto"
,
request_logger
=
None
)
request_logger
=
None
)
req
=
ChatCompletionRequest
(
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
,
messages
=
[{
messages
=
[{
...
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
...
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
)
)
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.5
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.5
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
...
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
...
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
req
.
temperature
=
0.1
req
.
temperature
=
0.1
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.1
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.1
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
...
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
...
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
req
.
temperature
=
0.0
req
.
temperature
=
0.0
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.0
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.0
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
def
test_serving_chat_did_set_correct_cache_salt
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_did_set_correct_cache_salt
():
mock_model_config
=
MockModelConfig
()
mock_model_config
=
MockModelConfig
()
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
...
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
...
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set
# By default cache_salt in the engine prompt is not set
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
"cache_salt"
not
in
mock_engine
.
generate
.
call_args
.
args
[
0
]
assert
"cache_salt"
not
in
mock_engine
.
generate
.
call_args
.
args
[
0
]
# Test with certain cache_salt
# Test with certain cache_salt
req
.
cache_salt
=
"test_salt"
req
.
cache_salt
=
"test_salt"
with
suppress
(
Exception
):
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
0
][
"cache_salt"
]
==
"test_salt"
assert
mock_engine
.
generate
.
call_args
.
args
[
0
][
"cache_salt"
]
==
"test_salt"
vllm/entrypoints/openai/serving_engine.py
View file @
a37d75bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
base64
import
base64
import
io
import
io
import
json
import
json
import
sys
import
sys
import
time
import
time
from
collections.abc
import
(
AsyncGenerator
,
Iterable
,
Iterator
,
Mapping
,
from
collections.abc
import
AsyncGenerator
,
Iterable
,
Mapping
,
Sequence
Sequence
)
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures.thread
import
ThreadPoolExecutor
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
(
Annotated
,
Any
,
Callable
,
ClassVar
,
Generic
,
Optional
,
from
typing
import
(
Annotated
,
Any
,
Callable
,
ClassVar
,
Generic
,
Optional
,
TypeVar
,
Union
,
cast
,
overload
)
TypeVar
,
Union
,
cast
,
overload
)
...
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
...
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
(
is_list_of
,
make_async
,
merge_async_iterators
,
from
vllm.utils
import
(
AsyncMicrobatchTokenizer
,
is_list_of
,
random_uuid
)
merge_async_iterators
,
random_uuid
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -226,11 +226,19 @@ class OpenAIServing:
...
@@ -226,11 +226,19 @@ class OpenAIServing:
self
.
_tokenizer_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
self
.
_tokenizer_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
self
.
_tokenize_prompt_input_async
=
make_async
(
self
.
_async_tokenizer_pool
:
dict
[
AnyTokenizer
,
self
.
_tokenize_prompt_input
,
executor
=
self
.
_tokenizer_executor
)
AsyncMicrobatchTokenizer
]
=
{}
self
.
_tokenize_prompt_input_or_inputs_async
=
make_async
(
self
.
_tokenize_prompt_input_or_inputs
,
def
_get_async_tokenizer
(
self
,
tokenizer
)
->
AsyncMicrobatchTokenizer
:
executor
=
self
.
_tokenizer_executor
)
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer
=
self
.
_async_tokenizer_pool
.
get
(
tokenizer
)
if
async_tokenizer
is
None
:
async_tokenizer
=
AsyncMicrobatchTokenizer
(
tokenizer
)
self
.
_async_tokenizer_pool
[
tokenizer
]
=
async_tokenizer
return
async_tokenizer
async
def
_preprocess
(
async
def
_preprocess
(
self
,
self
,
...
@@ -467,7 +475,7 @@ class OpenAIServing:
...
@@ -467,7 +475,7 @@ class OpenAIServing:
# if _check_model has been called earlier, this will be unreachable
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
def
_normalize_prompt_text_to_input
(
async
def
_normalize_prompt_text_to_input
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
...
@@ -475,38 +483,44 @@ class OpenAIServing:
...
@@ -475,38 +483,44 @@ class OpenAIServing:
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]],
add_special_tokens
:
bool
,
add_special_tokens
:
bool
,
)
->
TextTokensPrompt
:
)
->
TextTokensPrompt
:
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
if
(
self
.
model_config
.
encoder_config
is
not
None
if
(
self
.
model_config
.
encoder_config
is
not
None
and
self
.
model_config
.
encoder_config
.
get
(
and
self
.
model_config
.
encoder_config
.
get
(
"do_lower_case"
,
False
)):
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
prompt
=
prompt
.
lower
()
if
truncate_prompt_tokens
is
None
:
if
truncate_prompt_tokens
is
None
:
encoded
=
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
)
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
)
elif
truncate_prompt_tokens
<
0
:
elif
truncate_prompt_tokens
<
0
:
# Negative means we cap at the model's max length
# Negative means we cap at the model's max length
encoded
=
tokenizer
(
prompt
,
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
truncation
=
True
,
max_length
=
self
.
max_model_len
)
max_length
=
self
.
max_model_len
)
else
:
else
:
encoded
=
tokenizer
(
prompt
,
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
max_length
=
truncate_prompt_tokens
)
input_ids
=
encoded
.
input_ids
input_ids
=
encoded
.
input_ids
input_text
=
prompt
input_text
=
prompt
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
def
_normalize_prompt_tokens_to_input
(
async
def
_normalize_prompt_tokens_to_input
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
prompt_ids
:
list
[
int
],
prompt_ids
:
list
[
int
],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]],
)
->
TextTokensPrompt
:
)
->
TextTokensPrompt
:
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
if
truncate_prompt_tokens
is
None
:
if
truncate_prompt_tokens
is
None
:
input_ids
=
prompt_ids
input_ids
=
prompt_ids
elif
truncate_prompt_tokens
<
0
:
elif
truncate_prompt_tokens
<
0
:
...
@@ -514,7 +528,7 @@ class OpenAIServing:
...
@@ -514,7 +528,7 @@ class OpenAIServing:
else
:
else
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_text
=
tokenizer
.
decode
(
input_ids
)
input_text
=
await
async_
tokenizer
.
decode
(
input_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
...
@@ -578,7 +592,7 @@ class OpenAIServing:
...
@@ -578,7 +592,7 @@ class OpenAIServing:
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
def
_tokenize_prompt_input
(
async
def
_tokenize_prompt_input
_async
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
...
@@ -591,23 +605,24 @@ class OpenAIServing:
...
@@ -591,23 +605,24 @@ class OpenAIServing:
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input.
that assumes single input.
"""
"""
return
next
(
async
for
result
in
self
.
_tokenize_prompt_inputs_async
(
self
.
_tokenize_prompt_inputs
(
request
,
request
,
tokenizer
,
tokenizer
,
[
prompt_input
],
[
prompt_input
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
))
):
return
result
raise
ValueError
(
"No results yielded from tokenization"
)
def
_tokenize_prompt_inputs
(
async
def
_tokenize_prompt_inputs
_async
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
prompt_inputs
:
Iterable
[
Union
[
str
,
list
[
int
]]],
prompt_inputs
:
Iterable
[
Union
[
str
,
list
[
int
]]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
,
add_special_tokens
:
bool
=
True
,
)
->
It
erator
[
TextTokensPrompt
]:
)
->
AsyncGen
erator
[
TextTokensPrompt
,
None
]:
"""
"""
A simpler implementation of
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
...
@@ -615,7 +630,7 @@ class OpenAIServing:
...
@@ -615,7 +630,7 @@ class OpenAIServing:
"""
"""
for
text
in
prompt_inputs
:
for
text
in
prompt_inputs
:
if
isinstance
(
text
,
str
):
if
isinstance
(
text
,
str
):
yield
self
.
_normalize_prompt_text_to_input
(
yield
await
self
.
_normalize_prompt_text_to_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt
=
text
,
prompt
=
text
,
...
@@ -623,14 +638,14 @@ class OpenAIServing:
...
@@ -623,14 +638,14 @@ class OpenAIServing:
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
)
)
else
:
else
:
yield
self
.
_normalize_prompt_tokens_to_input
(
yield
await
self
.
_normalize_prompt_tokens_to_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt_ids
=
text
,
prompt_ids
=
text
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
)
)
def
_tokenize_prompt_input_or_inputs
(
async
def
_tokenize_prompt_input_or_inputs
_async
(
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
...
@@ -664,21 +679,31 @@ class OpenAIServing:
...
@@ -664,21 +679,31 @@ class OpenAIServing:
# VSCode Pyright extension should still work properly
# VSCode Pyright extension should still work properly
# "is False" is required for Pyright to perform type narrowing
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
# See: https://github.com/microsoft/pyright/issues/7672
inputs_text
.
extend
([
self
.
_normalize_prompt_text_to_input
(
# Parse and batch the input prompts
batch_inputs
=
parse_and_batch_prompt
(
input_or_inputs
)
# Process each input in the batch concurrently
tasks
=
[]
for
prompt_input
in
batch_inputs
:
if
prompt_input
[
"is_tokens"
]
is
False
:
task
=
self
.
_normalize_prompt_text_to_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt
=
prompt_input
[
"content"
],
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
)
add_special_tokens
=
add_special_tokens
)
if
prompt_input
[
"is_tokens"
]
is
False
else
else
:
self
.
_normalize_prompt_tokens_to_input
(
task
=
self
.
_normalize_prompt_tokens_to_input
(
request
,
request
,
tokenizer
,
tokenizer
,
prompt_ids
=
prompt_input
[
"content"
],
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
)
truncate_prompt_tokens
=
truncate_prompt_tokens
)
for
prompt_input
in
parse_and_batch_prompt
(
input_or_inputs
)
tasks
.
append
(
task
)
])
# Wait for all tokenization tasks to complete
results
=
await
asyncio
.
gather
(
*
tasks
)
inputs_text
.
extend
(
results
)
return
inputs_text
,
inputs_embeds
return
inputs_text
,
inputs_embeds
...
...
vllm/utils/__init__.py
View file @
a37d75bb
...
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
...
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Collection
,
Generator
,
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Collection
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
KeysView
,
Mapping
,
Hashable
,
Iterable
,
Iterator
,
KeysView
,
Mapping
,
Sequence
)
Sequence
)
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures.process
import
ProcessPoolExecutor
from
concurrent.futures.process
import
ProcessPoolExecutor
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
...
@@ -64,6 +65,7 @@ import zmq.asyncio
...
@@ -64,6 +65,7 @@ import zmq.asyncio
from
packaging
import
version
from
packaging
import
version
from
packaging.version
import
Version
from
packaging.version
import
Version
from
torch.library
import
Library
from
torch.library
import
Library
from
transformers.tokenization_utils_base
import
BatchEncoding
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -507,6 +509,196 @@ def random_uuid() -> str:
...
@@ -507,6 +509,196 @@ def random_uuid() -> str:
return
str
(
uuid
.
uuid4
().
hex
)
return
str
(
uuid
.
uuid4
().
hex
)
class
AsyncMicrobatchTokenizer
:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def
__init__
(
self
,
tokenizer
,
max_batch_size
:
int
=
32
,
batch_wait_timeout_s
:
float
=
0.002
,
)
->
None
:
self
.
tokenizer
=
tokenizer
self
.
max_batch_size
=
max_batch_size
self
.
batch_wait_timeout_s
=
batch_wait_timeout_s
self
.
_loop
=
asyncio
.
get_running_loop
()
self
.
_queues
:
dict
[
tuple
,
asyncio
.
Queue
[
Union
[
tuple
[
str
,
dict
,
asyncio
.
Future
],
tuple
[
list
[
int
],
asyncio
.
Future
]]]]
=
{}
self
.
_batcher_tasks
:
list
[
asyncio
.
Task
]
=
[]
# Single-thread executor for blocking tokenizer calls.
self
.
_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
# === Public async API ===
async
def
__call__
(
self
,
prompt
,
**
kwargs
):
result_future
:
asyncio
.
Future
=
self
.
_loop
.
create_future
()
key
=
self
.
_queue_key
(
"encode"
,
kwargs
)
queue
=
self
.
_get_queue
(
self
.
_loop
,
key
)
await
queue
.
put
((
prompt
,
kwargs
,
result_future
))
return
await
result_future
async
def
decode
(
self
,
token_ids
,
**
kwargs
):
result_future
:
asyncio
.
Future
=
self
.
_loop
.
create_future
()
key
=
self
.
_queue_key
(
"decode"
,
kwargs
)
queue
=
self
.
_get_queue
(
self
.
_loop
,
key
)
await
queue
.
put
((
token_ids
,
result_future
))
return
await
result_future
# === Internal helpers ===
def
_get_queue
(
self
,
loop
:
asyncio
.
AbstractEventLoop
,
key
:
tuple
)
->
asyncio
.
Queue
[
Union
[
tuple
[
str
,
dict
,
asyncio
.
Future
],
tuple
[
list
[
int
],
asyncio
.
Future
]]]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue
=
self
.
_queues
.
get
(
key
)
if
queue
is
None
:
self
.
_queues
[
key
]
=
queue
=
asyncio
.
Queue
()
if
key
[
0
]
==
"encode"
:
can_batch
=
key
[
1
]
!=
"other"
coro
=
self
.
_batch_encode_loop
(
queue
,
can_batch
)
else
:
assert
key
[
0
]
==
"decode"
,
\
f
"Unknown operation type:
{
key
[
0
]
}
."
coro
=
self
.
_batch_decode_loop
(
queue
)
self
.
_batcher_tasks
.
append
(
loop
.
create_task
(
coro
))
return
queue
async
def
_batch_encode_loop
(
self
,
queue
:
asyncio
.
Queue
,
can_batch
:
bool
):
"""Batch incoming encode requests for efficiency."""
while
True
:
prompt
,
kwargs
,
result_future
=
await
queue
.
get
()
prompts
=
[
prompt
]
kwargs_list
=
[
kwargs
]
result_futures
=
[
result_future
]
deadline
=
self
.
_loop
.
time
()
+
self
.
batch_wait_timeout_s
while
len
(
prompts
)
<
self
.
max_batch_size
:
timeout
=
deadline
-
self
.
_loop
.
time
()
if
timeout
<=
0
:
break
try
:
prompt
,
kwargs
,
result_future
=
await
asyncio
.
wait_for
(
queue
.
get
(),
timeout
)
prompts
.
append
(
prompt
)
result_futures
.
append
(
result_future
)
if
not
can_batch
:
kwargs_list
.
append
(
kwargs
)
except
asyncio
.
TimeoutError
:
break
try
:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if
can_batch
and
len
(
prompts
)
>
1
:
encode_fn
=
partial
(
self
.
tokenizer
,
prompts
,
**
kwargs
)
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
encode_fn
)
for
i
,
fut
in
enumerate
(
result_futures
):
if
not
fut
.
done
():
data
=
{
k
:
v
[
i
]
for
k
,
v
in
results
.
items
()}
fut
.
set_result
(
BatchEncoding
(
data
))
else
:
encode_fn
=
lambda
prompts
=
prompts
,
kwargs
=
kwargs_list
:
[
self
.
tokenizer
(
p
,
**
kw
)
for
p
,
kw
in
zip
(
prompts
,
kwargs
)
]
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
encode_fn
)
for
fut
,
res
in
zip
(
result_futures
,
results
):
if
not
fut
.
done
():
fut
.
set_result
(
res
)
except
Exception
as
e
:
for
fut
in
result_futures
:
if
not
fut
.
done
():
fut
.
set_exception
(
e
)
async
def
_batch_decode_loop
(
self
,
queue
:
asyncio
.
Queue
):
"""Batch incoming decode requests for efficiency."""
while
True
:
token_ids
,
result_future
=
await
queue
.
get
()
token_ids_list
=
[
token_ids
]
result_futures
=
[
result_future
]
deadline
=
self
.
_loop
.
time
()
+
self
.
batch_wait_timeout_s
while
len
(
token_ids_list
)
<
self
.
max_batch_size
:
timeout
=
deadline
-
self
.
_loop
.
time
()
if
timeout
<=
0
:
break
try
:
token_ids
,
result_future
=
await
asyncio
.
wait_for
(
queue
.
get
(),
timeout
)
token_ids_list
.
append
(
token_ids
)
result_futures
.
append
(
result_future
)
except
asyncio
.
TimeoutError
:
break
try
:
# Perform a single batched decode call for all requests
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
self
.
tokenizer
.
batch_decode
,
token_ids_list
)
for
fut
,
res
in
zip
(
result_futures
,
results
):
if
not
fut
.
done
():
fut
.
set_result
(
res
)
except
Exception
as
e
:
for
fut
in
result_futures
:
if
not
fut
.
done
():
fut
.
set_exception
(
e
)
def
_queue_key
(
self
,
op
:
str
,
kwargs
:
dict
)
->
tuple
:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False (`max_length` is None),
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if
op
==
"decode"
:
return
(
"decode"
,
)
add_special_tokens
=
kwargs
.
get
(
"add_special_tokens"
,
True
)
truncation
=
kwargs
.
get
(
"truncation"
,
False
)
max_length
=
kwargs
.
get
(
"max_length"
)
if
not
truncation
:
return
(
"encode"
,
add_special_tokens
,
False
,
None
)
model_max
=
getattr
(
self
.
tokenizer
,
"model_max_length"
,
None
)
if
max_length
is
None
or
(
model_max
is
not
None
and
max_length
==
model_max
):
return
(
"encode"
,
add_special_tokens
,
True
,
"model_max"
)
return
(
"encode"
,
"other"
)
def
__del__
(
self
):
for
task
in
self
.
_batcher_tasks
:
if
not
task
.
done
():
task
.
cancel
()
def
make_async
(
def
make_async
(
func
:
Callable
[
P
,
T
],
func
:
Callable
[
P
,
T
],
executor
:
Optional
[
concurrent
.
futures
.
Executor
]
=
None
executor
:
Optional
[
concurrent
.
futures
.
Executor
]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment