Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a37d75bb
Unverified
Commit
a37d75bb
authored
Jul 08, 2025
by
ztang2370
Committed by
GitHub
Jul 07, 2025
Browse files
[Front-end] microbatch tokenization (#19334)
Signed-off-by:
zt2370
<
ztang2370@gmail.com
>
parent
edd270bc
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
288 additions
and
64 deletions
+288
-64
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+23
-16
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+73
-48
vllm/utils/__init__.py
vllm/utils/__init__.py
+192
-0
No files found.
tests/entrypoints/openai/test_serving_chat.py
View file @
a37d75bb
...
...
@@ -7,6 +7,8 @@ from dataclasses import dataclass, field
from
typing
import
Any
,
Optional
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.config
import
MultiModalConfig
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
...
...
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
assert
serving_completion
.
chat_template
==
CHAT_TEMPLATE
def
test_serving_chat_should_set_correct_max_tokens
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_should_set_correct_max_tokens
():
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
mock_engine
.
get_tokenizer
.
return_value
=
get_tokenizer
(
MODEL_NAME
)
mock_engine
.
errored
=
False
...
...
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
chat_template
=
CHAT_TEMPLATE
,
chat_template_content_format
=
"auto"
,
request_logger
=
None
)
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
...
...
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
req
.
max_tokens
=
10
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
...
@@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
...
@@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
15
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
10
...
...
@@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
5
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
...
...
@@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens():
)
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
...
...
@@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
100
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
93
...
...
@@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens():
req
.
max_tokens
=
5
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
max_tokens
==
5
def
test_serving_chat_could_load_correct_generation_config
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_could_load_correct_generation_config
():
mock_model_config
=
MockModelConfig
()
mock_model_config
.
diff_sampling_param
=
{
...
...
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
chat_template
=
CHAT_TEMPLATE
,
chat_template_content_format
=
"auto"
,
request_logger
=
None
)
req
=
ChatCompletionRequest
(
model
=
MODEL_NAME
,
messages
=
[{
...
...
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
)
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.5
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
...
...
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
req
.
temperature
=
0.1
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.1
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
...
...
@@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config():
req
.
temperature
=
0.0
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
temperature
==
0.0
assert
mock_engine
.
generate
.
call_args
.
args
[
1
].
repetition_penalty
==
1.05
def
test_serving_chat_did_set_correct_cache_salt
():
@
pytest
.
mark
.
asyncio
async
def
test_serving_chat_did_set_correct_cache_salt
():
mock_model_config
=
MockModelConfig
()
mock_engine
=
MagicMock
(
spec
=
MQLLMEngineClient
)
...
...
@@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt():
# By default cache_salt in the engine prompt is not set
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
"cache_salt"
not
in
mock_engine
.
generate
.
call_args
.
args
[
0
]
# Test with certain cache_salt
req
.
cache_salt
=
"test_salt"
with
suppress
(
Exception
):
a
syncio
.
run
(
serving_chat
.
create_chat_completion
(
req
)
)
a
wait
serving_chat
.
create_chat_completion
(
req
)
assert
mock_engine
.
generate
.
call_args
.
args
[
0
][
"cache_salt"
]
==
"test_salt"
vllm/entrypoints/openai/serving_engine.py
View file @
a37d75bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
base64
import
io
import
json
import
sys
import
time
from
collections.abc
import
(
AsyncGenerator
,
Iterable
,
Iterator
,
Mapping
,
Sequence
)
from
concurrent.futures.thread
import
ThreadPoolExecutor
from
collections.abc
import
AsyncGenerator
,
Iterable
,
Mapping
,
Sequence
from
concurrent.futures
import
ThreadPoolExecutor
from
http
import
HTTPStatus
from
typing
import
(
Annotated
,
Any
,
Callable
,
ClassVar
,
Generic
,
Optional
,
TypeVar
,
Union
,
cast
,
overload
)
...
...
@@ -79,8 +79,8 @@ from vllm.sequence import Logprob, PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
(
is_list_of
,
make_async
,
merge_async_iterators
,
random_uuid
)
from
vllm.utils
import
(
AsyncMicrobatchTokenizer
,
is_list_of
,
merge_async_iterators
,
random_uuid
)
logger
=
init_logger
(
__name__
)
...
...
@@ -226,11 +226,19 @@ class OpenAIServing:
self
.
_tokenizer_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
self
.
_tokenize_prompt_input_async
=
make_async
(
self
.
_tokenize_prompt_input
,
executor
=
self
.
_tokenizer_executor
)
self
.
_tokenize_prompt_input_or_inputs_async
=
make_async
(
self
.
_tokenize_prompt_input_or_inputs
,
executor
=
self
.
_tokenizer_executor
)
self
.
_async_tokenizer_pool
:
dict
[
AnyTokenizer
,
AsyncMicrobatchTokenizer
]
=
{}
def
_get_async_tokenizer
(
self
,
tokenizer
)
->
AsyncMicrobatchTokenizer
:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer
=
self
.
_async_tokenizer_pool
.
get
(
tokenizer
)
if
async_tokenizer
is
None
:
async_tokenizer
=
AsyncMicrobatchTokenizer
(
tokenizer
)
self
.
_async_tokenizer_pool
[
tokenizer
]
=
async_tokenizer
return
async_tokenizer
async
def
_preprocess
(
self
,
...
...
@@ -467,7 +475,7 @@ class OpenAIServing:
# if _check_model has been called earlier, this will be unreachable
raise
ValueError
(
f
"The model `
{
request
.
model
}
` does not exist."
)
def
_normalize_prompt_text_to_input
(
async
def
_normalize_prompt_text_to_input
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
...
...
@@ -475,38 +483,44 @@ class OpenAIServing:
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]],
add_special_tokens
:
bool
,
)
->
TextTokensPrompt
:
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
if
(
self
.
model_config
.
encoder_config
is
not
None
and
self
.
model_config
.
encoder_config
.
get
(
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
if
truncate_prompt_tokens
is
None
:
encoded
=
tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
)
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
)
elif
truncate_prompt_tokens
<
0
:
# Negative means we cap at the model's max length
encoded
=
tokenizer
(
prompt
,
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
max_length
=
self
.
max_model_len
)
else
:
encoded
=
tokenizer
(
prompt
,
encoded
=
await
async_tokenizer
(
prompt
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
input_ids
=
encoded
.
input_ids
input_text
=
prompt
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
def
_normalize_prompt_tokens_to_input
(
async
def
_normalize_prompt_tokens_to_input
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt_ids
:
list
[
int
],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]],
)
->
TextTokensPrompt
:
async_tokenizer
=
self
.
_get_async_tokenizer
(
tokenizer
)
if
truncate_prompt_tokens
is
None
:
input_ids
=
prompt_ids
elif
truncate_prompt_tokens
<
0
:
...
...
@@ -514,7 +528,7 @@ class OpenAIServing:
else
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_text
=
tokenizer
.
decode
(
input_ids
)
input_text
=
await
async_
tokenizer
.
decode
(
input_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
...
...
@@ -578,7 +592,7 @@ class OpenAIServing:
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
def
_tokenize_prompt_input
(
async
def
_tokenize_prompt_input
_async
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
...
...
@@ -591,23 +605,24 @@ class OpenAIServing:
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input.
"""
return
next
(
self
.
_tokenize_prompt_inputs
(
async
for
result
in
self
.
_tokenize_prompt_inputs_async
(
request
,
tokenizer
,
[
prompt_input
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
,
))
):
return
result
raise
ValueError
(
"No results yielded from tokenization"
)
def
_tokenize_prompt_inputs
(
async
def
_tokenize_prompt_inputs
_async
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
prompt_inputs
:
Iterable
[
Union
[
str
,
list
[
int
]]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
,
add_special_tokens
:
bool
=
True
,
)
->
It
erator
[
TextTokensPrompt
]:
)
->
AsyncGen
erator
[
TextTokensPrompt
,
None
]:
"""
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
...
...
@@ -615,7 +630,7 @@ class OpenAIServing:
"""
for
text
in
prompt_inputs
:
if
isinstance
(
text
,
str
):
yield
self
.
_normalize_prompt_text_to_input
(
yield
await
self
.
_normalize_prompt_text_to_input
(
request
,
tokenizer
,
prompt
=
text
,
...
...
@@ -623,14 +638,14 @@ class OpenAIServing:
add_special_tokens
=
add_special_tokens
,
)
else
:
yield
self
.
_normalize_prompt_tokens_to_input
(
yield
await
self
.
_normalize_prompt_tokens_to_input
(
request
,
tokenizer
,
prompt_ids
=
text
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
)
def
_tokenize_prompt_input_or_inputs
(
async
def
_tokenize_prompt_input_or_inputs
_async
(
self
,
request
:
AnyRequest
,
tokenizer
:
AnyTokenizer
,
...
...
@@ -664,21 +679,31 @@ class OpenAIServing:
# VSCode Pyright extension should still work properly
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
inputs_text
.
extend
([
self
.
_normalize_prompt_text_to_input
(
# Parse and batch the input prompts
batch_inputs
=
parse_and_batch_prompt
(
input_or_inputs
)
# Process each input in the batch concurrently
tasks
=
[]
for
prompt_input
in
batch_inputs
:
if
prompt_input
[
"is_tokens"
]
is
False
:
task
=
self
.
_normalize_prompt_text_to_input
(
request
,
tokenizer
,
prompt
=
prompt_input
[
"content"
],
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
,
add_special_tokens
=
add_special_tokens
)
if
prompt_input
[
"is_tokens"
]
is
False
else
self
.
_normalize_prompt_tokens_to_input
(
else
:
task
=
self
.
_normalize_prompt_tokens_to_input
(
request
,
tokenizer
,
prompt_ids
=
prompt_input
[
"content"
],
prompt_input
[
"content"
],
truncate_prompt_tokens
=
truncate_prompt_tokens
)
for
prompt_input
in
parse_and_batch_prompt
(
input_or_inputs
)
])
tasks
.
append
(
task
)
# Wait for all tokenization tasks to complete
results
=
await
asyncio
.
gather
(
*
tasks
)
inputs_text
.
extend
(
results
)
return
inputs_text
,
inputs_embeds
...
...
vllm/utils/__init__.py
View file @
a37d75bb
...
...
@@ -41,6 +41,7 @@ from collections import UserDict, defaultdict
from
collections.abc
import
(
AsyncGenerator
,
Awaitable
,
Collection
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
KeysView
,
Mapping
,
Sequence
)
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures.process
import
ProcessPoolExecutor
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
...
...
@@ -64,6 +65,7 @@ import zmq.asyncio
from
packaging
import
version
from
packaging.version
import
Version
from
torch.library
import
Library
from
transformers.tokenization_utils_base
import
BatchEncoding
from
typing_extensions
import
Never
,
ParamSpec
,
TypeIs
,
assert_never
import
vllm.envs
as
envs
...
...
@@ -507,6 +509,196 @@ def random_uuid() -> str:
return
str
(
uuid
.
uuid4
().
hex
)
class
AsyncMicrobatchTokenizer
:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def
__init__
(
self
,
tokenizer
,
max_batch_size
:
int
=
32
,
batch_wait_timeout_s
:
float
=
0.002
,
)
->
None
:
self
.
tokenizer
=
tokenizer
self
.
max_batch_size
=
max_batch_size
self
.
batch_wait_timeout_s
=
batch_wait_timeout_s
self
.
_loop
=
asyncio
.
get_running_loop
()
self
.
_queues
:
dict
[
tuple
,
asyncio
.
Queue
[
Union
[
tuple
[
str
,
dict
,
asyncio
.
Future
],
tuple
[
list
[
int
],
asyncio
.
Future
]]]]
=
{}
self
.
_batcher_tasks
:
list
[
asyncio
.
Task
]
=
[]
# Single-thread executor for blocking tokenizer calls.
self
.
_executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
# === Public async API ===
async
def
__call__
(
self
,
prompt
,
**
kwargs
):
result_future
:
asyncio
.
Future
=
self
.
_loop
.
create_future
()
key
=
self
.
_queue_key
(
"encode"
,
kwargs
)
queue
=
self
.
_get_queue
(
self
.
_loop
,
key
)
await
queue
.
put
((
prompt
,
kwargs
,
result_future
))
return
await
result_future
async
def
decode
(
self
,
token_ids
,
**
kwargs
):
result_future
:
asyncio
.
Future
=
self
.
_loop
.
create_future
()
key
=
self
.
_queue_key
(
"decode"
,
kwargs
)
queue
=
self
.
_get_queue
(
self
.
_loop
,
key
)
await
queue
.
put
((
token_ids
,
result_future
))
return
await
result_future
# === Internal helpers ===
def
_get_queue
(
self
,
loop
:
asyncio
.
AbstractEventLoop
,
key
:
tuple
)
->
asyncio
.
Queue
[
Union
[
tuple
[
str
,
dict
,
asyncio
.
Future
],
tuple
[
list
[
int
],
asyncio
.
Future
]]]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue
=
self
.
_queues
.
get
(
key
)
if
queue
is
None
:
self
.
_queues
[
key
]
=
queue
=
asyncio
.
Queue
()
if
key
[
0
]
==
"encode"
:
can_batch
=
key
[
1
]
!=
"other"
coro
=
self
.
_batch_encode_loop
(
queue
,
can_batch
)
else
:
assert
key
[
0
]
==
"decode"
,
\
f
"Unknown operation type:
{
key
[
0
]
}
."
coro
=
self
.
_batch_decode_loop
(
queue
)
self
.
_batcher_tasks
.
append
(
loop
.
create_task
(
coro
))
return
queue
async
def
_batch_encode_loop
(
self
,
queue
:
asyncio
.
Queue
,
can_batch
:
bool
):
"""Batch incoming encode requests for efficiency."""
while
True
:
prompt
,
kwargs
,
result_future
=
await
queue
.
get
()
prompts
=
[
prompt
]
kwargs_list
=
[
kwargs
]
result_futures
=
[
result_future
]
deadline
=
self
.
_loop
.
time
()
+
self
.
batch_wait_timeout_s
while
len
(
prompts
)
<
self
.
max_batch_size
:
timeout
=
deadline
-
self
.
_loop
.
time
()
if
timeout
<=
0
:
break
try
:
prompt
,
kwargs
,
result_future
=
await
asyncio
.
wait_for
(
queue
.
get
(),
timeout
)
prompts
.
append
(
prompt
)
result_futures
.
append
(
result_future
)
if
not
can_batch
:
kwargs_list
.
append
(
kwargs
)
except
asyncio
.
TimeoutError
:
break
try
:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if
can_batch
and
len
(
prompts
)
>
1
:
encode_fn
=
partial
(
self
.
tokenizer
,
prompts
,
**
kwargs
)
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
encode_fn
)
for
i
,
fut
in
enumerate
(
result_futures
):
if
not
fut
.
done
():
data
=
{
k
:
v
[
i
]
for
k
,
v
in
results
.
items
()}
fut
.
set_result
(
BatchEncoding
(
data
))
else
:
encode_fn
=
lambda
prompts
=
prompts
,
kwargs
=
kwargs_list
:
[
self
.
tokenizer
(
p
,
**
kw
)
for
p
,
kw
in
zip
(
prompts
,
kwargs
)
]
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
encode_fn
)
for
fut
,
res
in
zip
(
result_futures
,
results
):
if
not
fut
.
done
():
fut
.
set_result
(
res
)
except
Exception
as
e
:
for
fut
in
result_futures
:
if
not
fut
.
done
():
fut
.
set_exception
(
e
)
async
def
_batch_decode_loop
(
self
,
queue
:
asyncio
.
Queue
):
"""Batch incoming decode requests for efficiency."""
while
True
:
token_ids
,
result_future
=
await
queue
.
get
()
token_ids_list
=
[
token_ids
]
result_futures
=
[
result_future
]
deadline
=
self
.
_loop
.
time
()
+
self
.
batch_wait_timeout_s
while
len
(
token_ids_list
)
<
self
.
max_batch_size
:
timeout
=
deadline
-
self
.
_loop
.
time
()
if
timeout
<=
0
:
break
try
:
token_ids
,
result_future
=
await
asyncio
.
wait_for
(
queue
.
get
(),
timeout
)
token_ids_list
.
append
(
token_ids
)
result_futures
.
append
(
result_future
)
except
asyncio
.
TimeoutError
:
break
try
:
# Perform a single batched decode call for all requests
results
=
await
self
.
_loop
.
run_in_executor
(
self
.
_executor
,
self
.
tokenizer
.
batch_decode
,
token_ids_list
)
for
fut
,
res
in
zip
(
result_futures
,
results
):
if
not
fut
.
done
():
fut
.
set_result
(
res
)
except
Exception
as
e
:
for
fut
in
result_futures
:
if
not
fut
.
done
():
fut
.
set_exception
(
e
)
def
_queue_key
(
self
,
op
:
str
,
kwargs
:
dict
)
->
tuple
:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False (`max_length` is None),
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if
op
==
"decode"
:
return
(
"decode"
,
)
add_special_tokens
=
kwargs
.
get
(
"add_special_tokens"
,
True
)
truncation
=
kwargs
.
get
(
"truncation"
,
False
)
max_length
=
kwargs
.
get
(
"max_length"
)
if
not
truncation
:
return
(
"encode"
,
add_special_tokens
,
False
,
None
)
model_max
=
getattr
(
self
.
tokenizer
,
"model_max_length"
,
None
)
if
max_length
is
None
or
(
model_max
is
not
None
and
max_length
==
model_max
):
return
(
"encode"
,
add_special_tokens
,
True
,
"model_max"
)
return
(
"encode"
,
"other"
)
def
__del__
(
self
):
for
task
in
self
.
_batcher_tasks
:
if
not
task
.
done
():
task
.
cancel
()
def
make_async
(
func
:
Callable
[
P
,
T
],
executor
:
Optional
[
concurrent
.
futures
.
Executor
]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment