OpenDAS / dynamo · Commits · ab33729b

fix: Add missing util files to vllm example (#105)

Commit ab33729b, authored Mar 11, 2025 by Neelay Shah; committed by GitHub on Mar 11, 2025.
Parent: b0655a34

Showing 5 changed files with 594 additions and 0 deletions (+594 -0)
examples/python_rs/llm/vllm/utils/chat_processor.py   +240  -0
examples/python_rs/llm/vllm/utils/nats_queue.py        +142  -0
examples/python_rs/llm/vllm/utils/nixl.py              +105  -0
examples/python_rs/llm/vllm/utils/prefill_queue.py      +56  -0
examples/python_rs/llm/vllm/utils/vllm.py               +51  -0
examples/python_rs/llm/vllm/utils/chat_processor.py (new file, mode 0 → 100644)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable

from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer


@runtime_checkable
class ProcessMixInRequired(Protocol):
    engine_args: AsyncEngineArgs
    chat_processor: "ChatProcessor | None"
    completions_processor: "CompletionsProcessor | None"
    model_config: ModelConfig


class ProcessMixIn(ProcessMixInRequired):
    """
    Mixin for pre and post processing for vLLM
    Requires engine_args, engine_client, processor, model_config to be initialized
    """

    engine_args: AsyncEngineArgs
    chat_processor: "ChatProcessor | None"
    completions_processor: "CompletionsProcessor | None"
    model_config: ModelConfig

    def __init__(self):
        pass

    def _get_processor(
        self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
    ):
        # Determine the processor type based on the request structure
        return (
            self.chat_processor
            if isinstance(raw_request, ChatCompletionRequest)
            else self.completions_processor
        )

    async def _parse_raw_request(
        self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
    ):
        processor = self._get_processor(raw_request)
        if processor is None:
            raise RuntimeError("Processor has not been initialized")
        request = processor.parse_raw_request(raw_request)
        preprocess_result = await processor.preprocess(raw_request)
        default_max_tokens = self.model_config.max_model_len - len(
            preprocess_result.engine_prompt["prompt_token_ids"]
        )
        default_sampling_params = self.model_config.get_diff_sampling_param()
        sampling_params = request.to_sampling_params(
            default_max_tokens,
            self.model_config.logits_processor_pattern,
            default_sampling_params,
        )
        return (
            request,
            preprocess_result.conversation,
            preprocess_result.request_prompt,
            preprocess_result.engine_prompt,
            sampling_params,
        )

    async def _stream_response(self, request, generator, request_id, conversation):
        processor = self._get_processor(request)
        if processor is None:
            raise RuntimeError("processor has not been initialized")
        return processor.stream_response(
            request,
            generator,
            request_id,
            conversation,
        )


class PreprocessResult:
    def __init__(
        self,
        conversation: Optional[ConversationMessage],
        request_prompt: RequestPrompt,
        engine_prompt: TokensPrompt,
    ):
        self.conversation = conversation
        self.request_prompt = request_prompt
        self.engine_prompt = engine_prompt


class ChatProcessor:
    def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.openai_serving = OpenAIServingChat(
            engine_client=None,
            model_config=model_config,
            models=None,
            request_logger=None,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
        )

    def parse_raw_request(
        self, raw_request: ChatCompletionRequest
    ) -> ChatCompletionRequest:
        return ChatCompletionRequest.parse_obj(raw_request)

    async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
        request = self.parse_raw_request(raw_request)
        (
            conversation,
            request_prompts,
            engine_prompts,
        ) = await self.openai_serving._preprocess_chat(
            request,
            self.tokenizer,
            request.messages,
            chat_template=request.chat_template or self.tokenizer.chat_template,
            chat_template_content_format=self.openai_serving.chat_template_content_format,
            add_generation_prompt=request.add_generation_prompt,
            continue_final_message=request.continue_final_message,
            tool_dicts=None,
            documents=request.documents,
            chat_template_kwargs=request.chat_template_kwargs,
            tool_parser=self.openai_serving.tool_parser,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )
        return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])

    async def stream_response(
        self,
        request: ChatCompletionRequest,
        result_generator: AsyncIterator,
        request_id: str,
        conversation: List,
    ):
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if not request.stream:
            raise ValueError("Only streaming responses are supported")

        async for raw_response in self.openai_serving.chat_completion_stream_generator(
            request,
            result_generator,
            request_id,
            request.model,
            conversation,
            self.tokenizer,
            request_metadata,
        ):
            if raw_response.startswith("data: [DONE]"):
                break
            response = json.loads(raw_response.lstrip("data: "))
            yield response


class CompletionsProcessor:
    def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.openai_serving = OpenAIServingCompletion(
            engine_client=None,
            model_config=model_config,
            models=None,
            request_logger=None,
        )

    def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
        return CompletionRequest.parse_obj(raw_request)

    async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
        request = self.parse_raw_request(raw_request)
        (
            request_prompts,
            engine_prompts,
        ) = await self.openai_serving._preprocess_completion(
            request,
            self.tokenizer,
            input_or_inputs=request.prompt,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )
        return PreprocessResult(None, request_prompts[0], engine_prompts[0])

    async def stream_response(
        self,
        request: CompletionRequest,
        result_generator: AsyncIterator,
        request_id: str,
        conversation: Optional[List[ConversationMessage]] = None,
    ):
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if not request.stream:
            raise ValueError("Only streaming responses are supported")

        async for raw_response in self.openai_serving.completion_stream_generator(
            request,
            result_generator,
            request_id,
            int(time.time()),  # created_time
            request.model,
            1,  # num_prompts
            self.tokenizer,
            request_metadata,
        ):
            if raw_response.startswith("data: [DONE]"):
                break
            response = json.loads(raw_response.lstrip("data: "))
            yield response
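
For orientation, a minimal sketch of how a serving component might combine ProcessMixIn with these processors. MyWorker, engine_generate, and the call pattern below are illustrative assumptions, not part of this commit:

# Hypothetical sketch only: wiring ChatProcessor/CompletionsProcessor into a worker.
class MyWorker(ProcessMixIn):
    def __init__(self, engine_args, tokenizer, model_config):
        self.engine_args = engine_args
        self.model_config = model_config
        self.chat_processor = ChatProcessor(tokenizer, model_config)
        self.completions_processor = CompletionsProcessor(tokenizer, model_config)

    async def generate(self, raw_request, engine_generate, request_id):
        # _parse_raw_request picks chat vs. completions and tokenizes the prompt
        (
            request,
            conversation,
            _request_prompt,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)
        generator = engine_generate(engine_prompt, sampling_params, request_id)
        # _stream_response yields OpenAI-style dicts parsed from the SSE stream
        return await self._stream_response(request, generator, request_id, conversation)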
examples/python_rs/llm/vllm/utils/nats_queue.py (new file, mode 0 → 100644)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from contextlib import asynccontextmanager
from typing import ClassVar, Optional

from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError


class NATSQueue:
    _instance: ClassVar[Optional["NATSQueue"]] = None
    _lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    def __init__(
        self,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        self.nats_url = nats_server
        self._nc: Optional[NATS] = None
        self._js: Optional[JetStreamContext] = None
        # TODO: check if this is needed
        # Sanitize stream_name to remove path separators
        self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
        self._subject = f"{self._stream_name}.*"
        self.dequeue_timeout = dequeue_timeout
        self._subscriber: Optional[JetStreamContext.PullSubscription] = None

    @classmethod
    @asynccontextmanager
    async def get_instance(
        cls,
        *,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        """Get or create a singleton instance of NATSQueue"""
        # TODO: check if this _lock is needed with GIL
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls(
                    stream_name=stream_name,
                    nats_server=nats_server,
                    dequeue_timeout=dequeue_timeout,
                )
                await cls._instance.connect()
            try:
                yield cls._instance
            except Exception:
                if cls._instance:
                    await cls._instance.close()
                    cls._instance = None
                raise

    # TODO: check to see if this can be replaced by something like get_instance().close()
    @classmethod
    async def shutdown(cls):
        """Explicitly close the singleton instance if it exists"""
        async with cls._lock:
            if cls._instance:
                await cls._instance.close()
                cls._instance = None

    async def connect(self):
        """Establish connection and create stream if needed"""
        try:
            if self._nc is None:
                self._nc = NATS()
                await self._nc.connect(self.nats_url)
                self._js = self._nc.jetstream()
                # Check if stream exists, if not create it
                try:
                    await self._js.stream_info(self._stream_name)
                except NotFoundError:
                    await self._js.add_stream(
                        name=self._stream_name, subjects=[self._subject]
                    )
                # Create persistent subscriber
                self._subscriber = await self._js.pull_subscribe(
                    f"{self._stream_name}.queue", durable="worker-group"
                )
        except NatsError as e:
            await self.close()
            raise ConnectionError(f"Failed to connect to NATS: {e}")

    async def ensure_connection(self):
        """Ensure we have an active connection"""
        if self._nc is None or self._nc.is_closed:
            await self.connect()

    async def close(self):
        """Close the connection when done"""
        if self._nc:
            await self._nc.close()
            self._nc = None
            self._js = None
            self._subscriber = None

    # TODO: is enqueue/dequeue_object a better name for a general queue?
    async def enqueue_task(self, task_data: bytes) -> None:
        """
        Enqueue a task using msgspec-encoded data
        """
        await self.ensure_connection()
        try:
            await self._js.publish(f"{self._stream_name}.queue", task_data)  # type: ignore
        except NatsError as e:
            raise RuntimeError(f"Failed to enqueue task: {e}")

    async def dequeue_task(self) -> Optional[bytes]:
        """Dequeue and return a task as raw bytes, to be decoded with msgspec"""
        await self.ensure_connection()
        try:
            msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout)  # type: ignore
            if msgs:
                msg = msgs[0]
                await msg.ack()
                return msg.data
            return None
        except asyncio.TimeoutError:
            return None
        except NatsError as e:
            raise RuntimeError(f"Failed to dequeue task: {e}")
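
A minimal usage sketch, assuming a NATS server with JetStream enabled is reachable at the default URL; the stream name and payload below are illustrative:

import asyncio

async def demo():
    # get_instance is an async context manager over the shared singleton
    async with NATSQueue.get_instance(stream_name="demo") as queue:
        await queue.enqueue_task(b"hello")   # payload is raw bytes
        task = await queue.dequeue_task()    # bytes, or None on timeout
        print(task)
    await NATSQueue.shutdown()               # close the shared connection

asyncio.run(demo())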
examples/python_rs/llm/vllm/utils/nixl.py (new file, mode 0 → 100644)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager

import msgspec
from vllm.distributed.device_communicators.nixl import NixlMetadata

from dynamo.runtime import DistributedRuntime

METADATA_DIR = "/tmp/nixl"


@contextmanager
def temp_metadata_file(engine_id, metadata: NixlMetadata):
    os.makedirs(METADATA_DIR, exist_ok=True)
    path = f"{METADATA_DIR}/{engine_id}.nixl_meta"
    with open(path, "wb") as f:
        encoded = msgspec.msgpack.encode(metadata)
        print(f"Size of encoded metadata: {len(encoded)}")
        f.write(encoded)
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.remove(path)


def find_remote_metadata(engine_id):
    # find and load metadata from METADATA_DIR that do not match engine_id
    remote_metadata = []
    for file in os.listdir(METADATA_DIR):
        if file.endswith(".nixl_meta"):
            if file.split(".")[0] != engine_id:
                with open(os.path.join(METADATA_DIR, file), "rb") as f:
                    remote_metadata.append(
                        msgspec.msgpack.decode(f.read(), type=NixlMetadata)
                    )
    return remote_metadata


class NixlMetadataStore:
    NIXL_METADATA_KEY = "nixl_metadata"

    def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
        self._namespace = namespace

        # TODO Remove metadata from etcd on delete
        self._stored: set[str] = set()
        self._cached: dict[str, NixlMetadata] = {}
        self._client = runtime.etcd_client()
        self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"

    async def put(self, engine_id, metadata: NixlMetadata):
        serialized_metadata = msgspec.msgpack.encode(metadata)
        key = "/".join([self._key_prefix, engine_id])
        await self._client.kv_put(key, serialized_metadata, None)
        self._stored.add(engine_id)

    async def get(self, engine_id) -> NixlMetadata:
        try:
            if engine_id in self._cached:
                return self._cached[engine_id]

            key = "/".join([self._key_prefix, engine_id])
            key_values = await self._client.kv_get_prefix(key)
            deserialized_metadata = None
            for item in key_values:
                deserialized_metadata = msgspec.msgpack.decode(
                    item["value"], type=NixlMetadata
                )
                break
            if deserialized_metadata is None:
                raise Exception("metadata not found in etcd")
            self._cached[engine_id] = deserialized_metadata
            # TODO watch for changes and update cache
            # self._client.add_watch_callback(
            #     key,
            #     self._watch_callback,
            # )
        except Exception as e:
            raise Exception(f"Error retrieving metadata for engine {engine_id}") from e

        return deserialized_metadata
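
A hedged sketch of how the store might be used, assuming a dynamo DistributedRuntime is already available and that engine_id / metadata come from the NIXL integration (neither is shown in this commit):

# Illustrative only: `runtime`, `engine_id`, and `metadata` are assumptions.
async def share_metadata(runtime, engine_id, metadata):
    store = NixlMetadataStore("demo-namespace", runtime)  # namespace is arbitrary here
    await store.put(engine_id, metadata)   # msgpack-encode and write to etcd
    return await store.get(engine_id)      # first get hits etcd, later gets use the cache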
examples/python_rs/llm/vllm/utils/prefill_queue.py (new file, mode 0 → 100644)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import msgspec
from utils.nats_queue import NATSQueue
from vllm.remote_prefill import RemotePrefillRequest


class PrefillQueue(NATSQueue):
    """
    A wrapper of NATSQueue for PrefillRequest.
    The stream name is forced to be "prefill_queue".
    """

    def __init__(
        self,
        stream_name="prefill_queue",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        super().__init__(
            stream_name=stream_name,
            nats_server=nats_server,
            dequeue_timeout=dequeue_timeout,
        )

    async def enqueue_prefill_request(
        self, prefill_request: RemotePrefillRequest
    ) -> None:
        encoded_request = msgspec.json.encode(prefill_request)
        await self.enqueue_task(encoded_request)

    async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
        encoded_request = await self.dequeue_task()
        if encoded_request is not None:
            prefill_request = msgspec.json.decode(
                encoded_request, type=RemotePrefillRequest
            )
            return prefill_request
        else:
            return None
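
A small round-trip sketch under assumptions: a NATS server is running and prefill_request is a RemotePrefillRequest built elsewhere (its construction is not part of this commit):

async def round_trip(prefill_request):
    queue = PrefillQueue()   # defaults to the "prefill_queue" stream
    await queue.connect()
    try:
        await queue.enqueue_prefill_request(prefill_request)
        return await queue.dequeue_prefill_request()   # None if nothing is pending
    finally:
        await queue.close()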
examples/python_rs/llm/vllm/utils/vllm.py (new file, mode 0 → 100644)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser


def parse_vllm_args() -> AsyncEngineArgs:
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--router",
        type=str,
        choices=["random", "round-robin", "kv"],
        default="random",
        help="Router type to use for scheduling requests to workers",
    )
    parser.add_argument(
        "--remote-prefill", action="store_true", help="Enable remote prefill"
    )
    parser.add_argument(
        "--conditional-disagg",
        action="store_true",
        help="Use disaggregated router to decide whether to prefill locally or remotely",
    )
    parser.add_argument(
        "--max-local-prefill-length",
        type=int,
        default=1000,
        help="Maximum length of local prefill",
    )
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine_args.router = args.router
    engine_args.remote_prefill = args.remote_prefill
    engine_args.conditional_disagg = args.conditional_disagg
    engine_args.max_local_prefill_length = args.max_local_prefill_length
    return engine_args
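
For context, a hedged sketch of an entrypoint consuming these arguments; the script name and flags shown in the comment are only an example invocation:

# Illustrative only: the extra flags ride along on the returned AsyncEngineArgs.
if __name__ == "__main__":
    engine_args = parse_vllm_args()
    print(engine_args.router, engine_args.remote_prefill, engine_args.max_local_prefill_length)
    # e.g. python worker.py --model <model> --router kv --remote-prefill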