Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
fda022b1
Unverified
Commit
fda022b1
authored
Mar 11, 2026
by
Graham King
Committed by
GitHub
Mar 11, 2026
Browse files
test(frontend): Minimal integration test for vllm processor (#7173)
Signed-off-by:
Graham King
<
grahamk@nvidia.com
>
parent
96a55928
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
366 additions
and
294 deletions
+366
-294
components/src/dynamo/frontend/vllm_processor.py
components/src/dynamo/frontend/vllm_processor.py
+2
-294
tests/frontend/test_vllm_prepost_integration.py
tests/frontend/test_vllm_prepost_integration.py
+277
-0
tests/frontend/vllm_prepost_worker.py
tests/frontend/vllm_prepost_worker.py
+87
-0
No files found.
components/src/dynamo/frontend/vllm_processor.py
View file @
fda022b1
...
...
@@ -11,9 +11,6 @@ import os
import
time
from
argparse
import
Namespace
from
collections.abc
import
AsyncGenerator
from
concurrent.futures
import
ProcessPoolExecutor
from
concurrent.futures
import
wait
as
_futures_wait
from
dataclasses
import
dataclass
from
typing
import
Any
from
vllm.config
import
CacheConfig
,
LoadConfig
,
ModelConfig
,
VllmConfig
...
...
@@ -38,12 +35,8 @@ from dynamo.llm import (
)
from
dynamo.runtime
import
Client
,
DistributedRuntime
from
.prepost
import
(
StreamingPostProcessor
,
preprocess_chat_request
,
preprocess_chat_request_sync
,
)
from
.utils
import
PreprocessError
,
random_uuid
,
worker_warmup
from
.prepost
import
StreamingPostProcessor
,
preprocess_chat_request
from
.utils
import
random_uuid
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -74,181 +67,6 @@ def map_finish_reason(raw_reason: str | None) -> FinishReason | None:
return
mapped
# --- Per-process preprocessing state ---
# Each name below is filled in by _init_worker (once per worker process) and
# is left as None in any process where that initializer has not run.
_w_input_processor: InputProcessor | None = None
_w_tokenizer: Any = None
_w_tool_parser_class: type[ToolParser] | None = None
@dataclass
class PreprocessWorkerResult:
    """Everything a preprocess worker hands back to the submitting process.

    The instance is returned through a process-pool future, so every field
    must be picklable.
    """

    # Request dict assembled by the worker: model, token_ids, stop/sampling/
    # output options, eos_token_ids and annotations.
    dynamo_preproc: dict[str, Any]
    # Prompt token ids produced by chat preprocessing.
    tokens: list[int]
    # Result of InputProcessor.process_inputs for this request.
    vllm_preproc: EngineCoreRequest
    # Sampling parameters as built by the worker before process_inputs ran.
    sampling_params: SamplingParams
    # ChatCompletionRequest (Pydantic model, picklable)
    request_for_sampling: Any
    # Chat-template keyword arguments reported by the preprocessing step.
    chat_template_kwargs: dict[str, Any]
def _init_worker(
    model_path: str,
    tokenizer_mode: str,
    config_format: str,
    load_format: str,
    tool_parser_name: str | None,
) -> None:
    """Initialize a worker process with its own VllmConfig and InputProcessor.

    Intended to run as the process-pool initializer. It populates the
    module-level ``_w_input_processor``, ``_w_tokenizer`` and
    ``_w_tool_parser_class`` that ``_preprocess_worker`` reads later.

    Args:
        model_path: Passed to ``ModelConfig`` as ``model``.
        tokenizer_mode: Passed to ``ModelConfig`` as ``tokenizer_mode``.
        config_format: Passed to ``ModelConfig`` as ``config_format``.
        load_format: Passed to ``LoadConfig`` as ``load_format``.
        tool_parser_name: Name looked up through ``ToolParserManager``; a
            falsy value (``None`` or ``""``) disables tool parsing.
    """
    # Only names that are actually assigned here are declared global. The
    # previous revision also declared ``_w_reasoning_parser_class``, which no
    # code in this module defines or assigns.
    global _w_input_processor, _w_tokenizer, _w_tool_parser_class

    model_config = ModelConfig(
        model=model_path,
        tokenizer_mode=tokenizer_mode,
        config_format=config_format,
    )
    vllm_config = VllmConfig(
        model_config=model_config,
        load_config=LoadConfig(load_format=load_format),
        cache_config=CacheConfig(),
    )

    _w_input_processor = InputProcessor(vllm_config)
    _w_tokenizer = _w_input_processor.get_tokenizer()
    _w_tool_parser_class = (
        ToolParserManager.get_tool_parser(tool_parser_name)
        if tool_parser_name
        else None
    )
def _preprocess_worker(
    request: dict[str, Any],
    request_id: str,
    model_name: str,
) -> PreprocessWorkerResult:
    """Run chat preprocessing inside a pool worker and package the outcome.

    Steps, in order: render/tokenize the chat request, build
    ``SamplingParams`` (generation-config values first, then explicit request
    values, then logprobs), assemble the token prompt, run it through the
    worker's ``InputProcessor``, and translate the result into the dict
    layout stored in ``PreprocessWorkerResult.dynamo_preproc``.

    Args:
        request: Raw chat-completion request dict.
        request_id: Identifier handed to ``InputProcessor.process_inputs``.
        model_name: Value stored under ``"model"`` in the dynamo request.

    Returns:
        A picklable ``PreprocessWorkerResult``.

    Raises:
        AssertionError: If the worker globals were never initialized.
        PreprocessError: If the processed sampling params request ``n != 1``.
    """
    processor = _w_input_processor
    assert processor is not None

    pre = preprocess_chat_request_sync(
        request,
        tokenizer=_w_tokenizer,
        renderer=processor.renderer,
        tool_parser_class=_w_tool_parser_class,
    )
    chat_request = pre.request_for_sampling
    engine_prompt = pre.engine_prompt
    tokens = pre.prompt_token_ids

    # max_completion_tokens takes precedence over max_tokens; when both are
    # unset the limit stays None.
    token_limit = chat_request.max_completion_tokens
    if token_limit is None:
        token_limit = chat_request.max_tokens

    sampling_params = SamplingParams(
        output_kind=RequestOutputKind.DELTA,
        max_tokens=token_limit,
    )

    # Layer 1: values from the processor's generation config, applied only to
    # attributes SamplingParams actually has.
    for name, configured in processor.generation_config_fields.items():
        if hasattr(sampling_params, name):
            setattr(sampling_params, name, configured)

    # Layer 2: explicit (non-None) request values override layer 1. The three
    # excluded names are set elsewhere in this function.
    set_elsewhere = {"max_tokens", "logprobs", "output_kind"}
    annotated = set(getattr(SamplingParams, "__annotations__", ()))
    request_fields = set(type(chat_request).model_fields)
    for name in sorted((annotated & request_fields) - set_elsewhere):
        requested = getattr(chat_request, name, None)
        if requested is not None:
            setattr(sampling_params, name, requested)

    # logprobs may be a bool flag or an int count; bool is tested first
    # because bool is a subclass of int.
    logprobs = chat_request.logprobs
    top_logprobs = chat_request.top_logprobs
    if logprobs is True:
        sampling_params.logprobs = top_logprobs or 1
    elif isinstance(logprobs, int) and not isinstance(logprobs, bool):
        sampling_params.logprobs = logprobs
    elif top_logprobs not in (None, 0):
        sampling_params.logprobs = top_logprobs

    prompt_inputs = TokensPrompt(prompt_token_ids=tokens)
    # Multimodal payloads are copied only when the rendered prompt has them.
    for key in ("multi_modal_data", "multi_modal_uuids"):
        if key in engine_prompt:
            prompt_inputs[key] = engine_prompt[key]
    # Optional request-level fields are copied only when explicitly set.
    for key in ("cache_salt", "mm_processor_kwargs"):
        supplied = getattr(chat_request, key)
        if supplied is not None:
            prompt_inputs[key] = supplied

    vllm_preproc: EngineCoreRequest = processor.process_inputs(
        request_id,
        prompt_inputs,
        sampling_params,
    )
    InputProcessor.assign_request_id(vllm_preproc)

    # From here on, read the sampling params as returned by process_inputs.
    sp = vllm_preproc.sampling_params
    if sp.n != 1:
        message = (
            f"Unsupported value: 'n={sp.n}'. "
            "This endpoint currently supports only n=1."
        )
        raise PreprocessError(
            {
                "error": {
                    "message": message,
                    "type": "invalid_request_error",
                    "param": "n",
                    "code": "unsupported_value",
                }
            }
        )

    stop_conditions = {
        "max_tokens": sp.max_tokens,
        "stop": sp.stop,
        "stop_token_ids": sp.stop_token_ids,
        "min_tokens": sp.min_tokens,
        "ignore_eos": sp.ignore_eos,
    }
    sampling_options = {
        "n": sp.n,
        "presence_penalty": sp.presence_penalty,
        "frequency_penalty": sp.frequency_penalty,
        "repetition_penalty": sp.repetition_penalty,
        "temperature": sp.temperature,
        "top_p": sp.top_p,
        "top_k": sp.top_k,
        "min_p": sp.min_p,
        "seed": sp.seed,
    }
    output_options = {
        "logprobs": sp.logprobs,
        "prompt_logprobs": sp.prompt_logprobs,
        "skip_special_tokens": sp.skip_special_tokens,
    }
    eos_token_id = vllm_preproc.eos_token_id
    dynamo_preproc = {
        "model": model_name,
        "token_ids": tokens,
        "stop_conditions": stop_conditions,
        "sampling_options": sampling_options,
        "output_options": output_options,
        "eos_token_ids": [] if eos_token_id is None else [eos_token_id],
        "annotations": [],
    }

    return PreprocessWorkerResult(
        dynamo_preproc=dynamo_preproc,
        tokens=tokens,
        vllm_preproc=vllm_preproc,
        sampling_params=sampling_params,
        request_for_sampling=chat_request,
        chat_template_kwargs=pre.chat_template_kwargs,
    )
class
VllmProcessor
:
def
__init__
(
self
,
...
...
@@ -526,77 +344,6 @@ class VllmProcessor:
[
vllm_preproc
.
request_id
],
internal
=
True
)
async def _generator_inner_pool(
    self, request: dict[str, Any]
) -> AsyncGenerator[dict[str, Any], None]:
    """Serve one request, delegating preprocessing to the worker pool.

    Phase 1 submits ``_preprocess_worker`` to the pool while holding the
    worker semaphore. Phase 2/3 build a ``StreamingPostProcessor`` in this
    process and relay whatever ``_generate_and_stream`` yields.

    Yields:
        Items from ``_generate_and_stream``, or a single error dict when
        preprocessing fails (after which the generator ends).
    """
    request_id = random_uuid()

    # --- Phase 1: preprocessing in a worker, guarded by the semaphore ---
    try:
        semaphore = self._worker_semaphore
        assert semaphore is not None
        async with semaphore:
            pool = self.preprocess_pool
            assert pool is not None
            pending = pool.submit(
                _preprocess_worker, request, request_id, request["model"]
            )
            preproc_result: PreprocessWorkerResult = await asyncio.wrap_future(
                pending
            )
        # Leaving the ``async with`` releases the semaphore.
    except PreprocessError as exc:
        # Request-level validation failure: forward the prepared error body.
        yield exc.error_dict
        return
    except Exception as exc:
        logger.exception("Worker preprocessing failed for request %s", request_id)
        yield {
            "error": {
                "message": f"Worker error: {exc}",
                "type": "internal_error",
            }
        }
        return

    # --- Between phases: build main-process objects from the result ---
    chat_request = preproc_result.request_for_sampling
    tokens = preproc_result.tokens

    # A tool parser is created only when one is configured, the request
    # supplies tools, and tool use has not been turned off via "none".
    wants_tool_parsing = (
        self.tool_parser_class
        and chat_request.tools
        and chat_request.tool_choice != "none"
    )
    tool_parser = (
        self.tool_parser_class(self.tokenizer) if wants_tool_parsing else None
    )

    post = StreamingPostProcessor(
        tokenizer=self.tokenizer,
        request_for_sampling=chat_request,
        sampling_params=preproc_result.sampling_params,
        prompt_token_ids=tokens,
        tool_parser=tool_parser,
        reasoning_parser_class=self.reasoning_parser_class,
        chat_template_kwargs=preproc_result.chat_template_kwargs,
    )

    # --- Phases 2 and 3: generation and streaming post-processing ---
    async for item in self._generate_and_stream(
        request_id,
        request,
        preproc_result.dynamo_preproc,
        tokens,
        preproc_result.vllm_preproc,
        post,
    ):
        yield item
class
EngineFactory
:
def
__init__
(
...
...
@@ -705,45 +452,6 @@ class EngineFactory:
router_mode
=
self
.
router_config
.
router_mode
)
preprocess_pool
=
None
preprocess_workers
=
self
.
config
.
preprocess_workers
if
preprocess_workers
>
0
:
logger
.
info
(
"Creating preprocess worker pool with %d workers for model %s"
,
preprocess_workers
,
source_path
,
)
preprocess_pool
=
ProcessPoolExecutor
(
max_workers
=
preprocess_workers
,
initializer
=
_init_worker
,
initargs
=
(
source_path
,
tokenizer_mode
,
config_format
,
load_format
,
tool_parser_name
,
),
)
# Warm up all workers to ensure initialization completes
futures
=
[
preprocess_pool
.
submit
(
worker_warmup
)
for
_
in
range
(
preprocess_workers
)
]
done
,
not_done
=
_futures_wait
(
futures
,
timeout
=
120
)
if
not_done
:
for
f
in
not_done
:
f
.
cancel
()
preprocess_pool
.
shutdown
(
wait
=
False
,
cancel_futures
=
True
)
raise
RuntimeError
(
"Timed out waiting for preprocess worker pool warmup"
)
try
:
for
f
in
done
:
f
.
result
()
# Raises if initializer failed
except
Exception
:
preprocess_pool
.
shutdown
(
wait
=
False
,
cancel_futures
=
True
)
raise
logger
.
info
(
"Preprocess worker pool ready (%d workers)"
,
preprocess_workers
)
gen
=
VllmProcessor
(
tokenizer
,
input_processor
,
...
...
tests/frontend/test_vllm_prepost_integration.py
0 → 100644
View file @
fda022b1
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
json
import
logging
import
os
import
time
from
pathlib
import
Path
from
typing
import
Any
,
Generator
import
pytest
import
requests
from
tests.utils.constants
import
QWEN
from
tests.utils.managed_process
import
DynamoFrontendProcess
,
ManagedProcess
from
tests.utils.port_utils
import
ServicePorts
logger = logging.getLogger(__name__)

# Model id the test requests and looks for in the /v1/models listing.
TEST_MODEL = QWEN
# Name of the environment variable carrying the capture-file path to the worker.
CAPTURE_PATH_ENV = "DYN_VLLM_PREPOST_CAPTURE_PATH"

# Single function-tool definition offered to the model in the chat request.
SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "search_gutenberg_books",
        "description": "Search for books in the Project Gutenberg library",
        "parameters": {
            "type": "object",
            "properties": {
                "search_terms": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of search terms to find books",
                }
            },
            "required": ["search_terms"],
        },
    },
}

pytestmark = [
    pytest.mark.vllm,
    # vllm frontend doesn't need or use the GPU, but in CI pytorch seems to look for the Device
    pytest.mark.gpu_1,
    pytest.mark.pre_merge,
    pytest.mark.integration,
    pytest.mark.parallel,
    pytest.mark.model(TEST_MODEL),
]
class MockVllmPrepostWorkerProcess(ManagedProcess):
    """Test worker that captures frontend tokenized requests.

    Launches ``python3 -m tests.frontend.vllm_prepost_worker`` with the
    capture-file path exported through ``CAPTURE_PATH_ENV``. The process is
    considered healthy once the frontend's ``/v1/models`` lists ``TEST_MODEL``.
    """

    def __init__(
        self,
        request,
        *,
        frontend_port: int,
        capture_path: Path,
        worker_id: str = "vllm-prepost-worker",
    ) -> None:
        """Configure the worker subprocess.

        Args:
            request: pytest request object; its node name prefixes the log dir.
            frontend_port: Port of the frontend polled by the health check.
            capture_path: File the worker writes the captured request to.
            worker_id: Suffix for the log directory name.
        """
        env = os.environ.copy()
        env[CAPTURE_PATH_ENV] = str(capture_path)

        super().__init__(
            command=["python3", "-m", "tests.frontend.vllm_prepost_worker"],
            env=env,
            health_check_urls=[
                (
                    f"http://localhost:{frontend_port}/v1/models",
                    self._check_models_api,
                )
            ],
            timeout=60,
            display_output=True,
            terminate_all_matching_process_names=False,
            straggler_commands=["-m tests.frontend.vllm_prepost_worker"],
            log_dir=f"{request.node.name}_{worker_id}",
        )

    @staticmethod
    def _check_models_api(response: requests.Response) -> bool:
        """Return True when the models listing contains ``TEST_MODEL``.

        Any response that is not a 200 with a JSON object holding a ``data``
        list counts as "not ready" and yields False. Unexpected but valid
        JSON shapes (a top-level list, ``"data": null``, non-dict entries)
        no longer raise from inside the health check.
        """
        if response.status_code != 200:
            return False
        try:
            payload = response.json()
        except (ValueError, KeyError):
            return False

        if not isinstance(payload, dict):
            return False
        models = payload.get("data") or []
        if not isinstance(models, list):
            return False
        return any(
            isinstance(model, dict) and model.get("id") == TEST_MODEL
            for model in models
        )
def _read_captured_request(path: Path, timeout_s: float = 20.0) -> dict[str, Any]:
    """Wait for the capture file to appear and return its parsed JSON content.

    The file is polled every 0.1 s. The deadline is measured on the monotonic
    clock, so wall-clock adjustments cannot shorten or stretch the wait. The
    file is always checked at least once, including once more after the final
    sleep, so a file that already exists is returned even with
    ``timeout_s <= 0``.

    Args:
        path: Location of the JSON capture file.
        timeout_s: Maximum number of seconds to keep polling.

    Returns:
        The decoded JSON document.

    Raises:
        AssertionError: If the file has not appeared before the deadline.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            # Read directly instead of exists()+read: one lookup, no window
            # between the check and the read.
            return json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            pass
        if time.monotonic() >= deadline:
            break
        time.sleep(0.1)
    raise AssertionError(f"Timed out waiting for captured request at {path}")
def _collect_stream_chunks(response: requests.Response) -> list[dict[str, Any]]:
    """Read an SSE chat-completion stream and return its JSON chunks.

    Every non-empty line must look like ``data: <payload>``. Payloads are
    JSON-decoded and collected until the ``[DONE]`` sentinel, which ends
    reading. The response is closed afterwards so the streamed connection
    is released even when an assertion fails.

    Args:
        response: A streamed HTTP response (``stream=True``).

    Returns:
        The decoded chunks, in arrival order, excluding the ``[DONE]`` marker.

    Raises:
        AssertionError: On a non-``data:`` line, a missing ``[DONE]`` marker,
            or a stream without any chunk.
    """
    prefix = "data: "
    chunks: list[dict[str, Any]] = []
    saw_done = False
    try:
        response.raise_for_status()
        # Iterate raw bytes and decode as UTF-8 here. With
        # decode_unicode=True requests uses its own guess for the encoding
        # (ISO-8859-1 for text/* without a charset, or no decoding at all
        # when it has none), which can corrupt non-ASCII content or hand
        # back bytes.
        for raw_line in response.iter_lines():
            line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
            if not line:
                continue
            assert line.startswith(prefix), f"Unexpected SSE line: {line!r}"
            payload = line[len(prefix) :]
            if payload == "[DONE]":
                saw_done = True
                break
            chunks.append(json.loads(payload))
    finally:
        response.close()

    assert saw_done, "Missing [DONE] marker in SSE stream"
    assert chunks, "Expected streamed chunks but got none"
    return chunks
def _collect_reasoning(chunks: list[dict[str, Any]]) -> str:
    """Concatenate all streamed ``reasoning_content`` deltas in order.

    Chunks whose ``choices`` is missing or null, and choices whose ``delta``
    is missing or null, contribute nothing. Null ``choices`` is handled the
    same way a null ``delta`` already was.

    Args:
        chunks: Decoded streaming chunks.

    Returns:
        The joined reasoning text, or ``""`` when there is none.
    """
    parts: list[str] = []
    for chunk in chunks:
        for choice in chunk.get("choices") or []:
            reasoning = (choice.get("delta") or {}).get("reasoning_content")
            if reasoning is not None:
                parts.append(reasoning)
    return "".join(parts)
def _collect_tool_calls(chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge streamed tool-call fragments into complete tool calls.

    Fragments are grouped by their ``index``. ``id``, ``type`` and the
    function ``name`` keep the first non-empty value seen; ``arguments``
    fragments are concatenated in arrival order.

    Null ``choices``, ``delta``, ``tool_calls``, ``function`` or ``arguments``
    values are treated as absent. In particular, a first fragment with
    ``"arguments": null`` starts from an empty string, so later fragments can
    be appended.

    Args:
        chunks: Decoded streaming chunks.

    Returns:
        Merged tool calls ordered by ascending index.

    Raises:
        KeyError: If a tool-call fragment has no ``index``.
    """
    merged: dict[int, dict[str, Any]] = {}
    for chunk in chunks:
        for choice in chunk.get("choices") or []:
            fragments = (choice.get("delta") or {}).get("tool_calls") or []
            for tool_call in fragments:
                idx = tool_call["index"]
                incoming_fn = tool_call.get("function") or {}

                existing = merged.get(idx)
                if existing is None:
                    # First fragment for this index seeds the entry.
                    merged[idx] = {
                        "id": tool_call.get("id"),
                        "type": tool_call.get("type"),
                        "function": {
                            "name": incoming_fn.get("name"),
                            "arguments": incoming_fn.get("arguments") or "",
                        },
                    }
                    continue

                # Later fragments only fill gaps, except for arguments,
                # which accumulate.
                if tool_call.get("id") and not existing["id"]:
                    existing["id"] = tool_call["id"]
                if tool_call.get("type") and not existing["type"]:
                    existing["type"] = tool_call["type"]
                if incoming_fn.get("name") and not existing["function"]["name"]:
                    existing["function"]["name"] = incoming_fn["name"]
                if incoming_fn.get("arguments"):
                    existing["function"]["arguments"] += incoming_fn["arguments"]
    return [merged[idx] for idx in sorted(merged)]
@pytest.fixture(scope="function")
def start_services(
    request,
    runtime_services_dynamic_ports,
    dynamo_dynamic_ports: ServicePorts,
    tmp_path: Path,
) -> Generator[tuple[int, Path], None, None]:
    """Start the frontend and the mock worker for the duration of one test.

    Yields:
        ``(frontend_port, capture_path)``: the port the frontend listens on
        and the file the mock worker writes the captured request to.
    """
    # Requested only so the fixture runs; its value is not used here.
    _ = runtime_services_dynamic_ports

    frontend_port = dynamo_dynamic_ports.frontend_port
    capture_path = tmp_path / "captured_request.json"

    frontend_args = [
        "--dyn-chat-processor",
        "vllm",
        "--discovery-backend",
        "etcd",  # Started by the fixture
        "--request-plane",
        "tcp",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
        "--reasoning-parser",
        "qwen3",
    ]
    frontend = DynamoFrontendProcess(
        request,
        frontend_port=frontend_port,
        extra_args=frontend_args,
        extra_env={"DYN_VLLM_STREAM_INTERVAL": "20"},
        terminate_all_matching_process_names=False,
    )
    with frontend:
        logger.info("Frontend started on port %s", frontend_port)
        # The worker is created only after the frontend context is entered.
        worker = MockVllmPrepostWorkerProcess(
            request,
            frontend_port=frontend_port,
            capture_path=capture_path,
        )
        with worker:
            logger.info("vLLM pre/post test worker registered model %s", TEST_MODEL)
            yield frontend_port, capture_path
@pytest.mark.timeout(120)
def test_vllm_chat_processor_tokenizes_and_streams_tool_calls(
    start_services: tuple[int, Path],
) -> None:
    """End-to-end check of the vLLM chat processor in the frontend.

    Sends one streaming chat request with a tool definition, then verifies
    (a) what the mock worker captured: model name, token ids and decoded
    prompt, and (b) what the frontend streamed back: reasoning text, one
    merged tool call, no raw tool-call markup in the content, and sane
    finish reasons.
    """
    frontend_port, capture_path = start_services

    user_message = {
        "role": "user",
        "content": "What are the titles of some James Joyce books? Use the tool to search.",
    }
    payload = {
        "model": TEST_MODEL,
        "messages": [user_message],
        "tools": [SEARCH_TOOL],
        "tool_choice": "auto",
        "stream": True,
        "max_tokens": 128,
    }

    response = requests.post(
        f"http://localhost:{frontend_port}/v1/chat/completions",
        json=payload,
        timeout=60,
        stream=True,
    )
    chunks = _collect_stream_chunks(response)

    # --- Preprocessing side: what reached the worker ---
    captured = _read_captured_request(capture_path)
    assert captured["model"] == TEST_MODEL
    token_ids = captured["token_ids"]
    assert isinstance(token_ids, list)
    assert token_ids

    decoded_prompt = captured["decoded_prompt"]
    assert "What are the titles of some James Joyce books?" in decoded_prompt
    assert "search_gutenberg_books" in decoded_prompt

    # --- Postprocessing side: what the frontend streamed ---
    reasoning = _collect_reasoning(chunks)
    assert "titles of some James Joyce books" in reasoning

    tool_calls = _collect_tool_calls(chunks)
    assert len(tool_calls) == 1
    called_function = tool_calls[0]["function"]
    assert called_function["name"] == "search_gutenberg_books"
    assert json.loads(called_function["arguments"]) == {
        "search_terms": ["James Joyce", "Project Gutenberg"],
    }

    all_choices = [
        choice for chunk in chunks for choice in chunk.get("choices", [])
    ]

    # Tool-call markup must have been parsed out of the visible content.
    content = "".join(
        (choice.get("delta") or {}).get("content") or "" for choice in all_choices
    )
    assert "<tool_call>" not in content
    assert "</tool_call>" not in content

    finish_reasons = list(
        filter(None, (choice.get("finish_reason") for choice in all_choices))
    )
    assert finish_reasons, "Expected at least one finish_reason"
    assert set(finish_reasons) <= {"stop", "tool_calls"}
tests/frontend/vllm_prepost_worker.py
0 → 100644
View file @
fda022b1
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Lightweight token-based worker for vLLM frontend pre/post integration tests."""
from
__future__
import
annotations
import
asyncio
import
json
import
os
from
pathlib
import
Path
from
typing
import
Any
import
uvloop
from
transformers
import
AutoTokenizer
from
dynamo.llm
import
ModelInput
,
ModelType
,
register_model
from
dynamo.runtime
import
DistributedRuntime
from
tests.frontend.test_prepost
import
OUTPUTS_INTERVAL_20
from
tests.frontend.test_vllm_prepost_integration
import
CAPTURE_PATH_ENV
from
tests.utils.constants
import
QWEN
class VllmPrepostTestHandler:
    """Record each tokenized request and answer with a canned token stream."""

    # Request keys copied verbatim into the capture file, after "model" and
    # "token_ids" and before "decoded_prompt".
    _PASSTHROUGH_KEYS = (
        "stop_conditions",
        "sampling_options",
        "output_options",
        "eos_token_ids",
    )

    def __init__(self, model_name: str = QWEN):
        """Load the tokenizer used to decode captured prompts."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _write_capture(self, request: dict[str, Any]) -> None:
        """Dump the request (plus its decoded prompt) to the capture file.

        Does nothing when ``CAPTURE_PATH_ENV`` is unset or empty. The JSON is
        written to a sibling ``.tmp`` file and then moved over the target, so
        a reader never sees a partially written capture file.
        """
        target = os.environ.get(CAPTURE_PATH_ENV)
        if not target:
            return

        token_ids = request.get("token_ids", [])
        captured: dict[str, Any] = {
            "model": request.get("model"),
            "token_ids": token_ids,
        }
        captured.update((key, request.get(key)) for key in self._PASSTHROUGH_KEYS)
        captured["decoded_prompt"] = self.tokenizer.decode(
            token_ids,
            skip_special_tokens=False,
        )

        path = Path(target)
        path.parent.mkdir(parents=True, exist_ok=True)
        staging = path.with_suffix(path.suffix + ".tmp")
        staging.write_text(json.dumps(captured), encoding="utf-8")
        staging.replace(path)

    async def generate(self, request: dict[str, Any], context):
        """Capture the request, then stream the ``OUTPUTS_INTERVAL_20`` items.

        Each yielded dict carries ``token_ids`` and, only when the source
        output sets them, ``finish_reason`` and ``stop_reason``.
        """
        self._write_capture(request)
        for output in OUTPUTS_INTERVAL_20:
            chunk = {"token_ids": list(output.token_ids)}
            for field in ("finish_reason", "stop_reason"):
                value = getattr(output, field)
                if value is not None:
                    chunk[field] = value
            yield chunk
async def main():
    """Register a token-input chat model and serve canned responses for it."""
    loop = asyncio.get_running_loop()
    runtime = DistributedRuntime(loop, "etcd", "tcp", enable_nats=False)
    endpoint = runtime.endpoint("test.vllm-prepost.generate")

    # Register before building the handler, which loads the tokenizer.
    await register_model(
        ModelInput.Tokens,
        ModelType.Chat,
        endpoint,
        QWEN,
        model_name=QWEN,
    )

    handler = VllmPrepostTestHandler(QWEN)
    await endpoint.serve_endpoint(handler.generate)


if __name__ == "__main__":
    uvloop.run(main())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment