Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
caf7ff44
Unverified
Commit
caf7ff44
authored
Feb 19, 2025
by
Nick Hill
Committed by
GitHub
Feb 19, 2025
Browse files
[V1][Core] Generic mechanism for handling engine utility (#13060)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
f525c0be
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
197 additions
and
56 deletions
+197
-56
tests/lora/test_add_lora.py
tests/lora/test_add_lora.py
+1
-1
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+49
-8
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+18
-6
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+33
-16
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+96
-25
No files found.
tests/lora/test_add_lora.py
View file @
caf7ff44
...
...
@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
]
for
tokenizer_file
in
tokenizer_files
:
del_path
=
Path
(
LORA_MODULE_DOWNLOAD_PATH
)
/
tokenizer_file
del_path
.
unlink
()
del_path
.
unlink
(
missing_ok
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
tests/v1/engine/test_engine_core_client.py
View file @
caf7ff44
...
...
@@ -3,7 +3,8 @@
import
asyncio
import
time
import
uuid
from
typing
import
Dict
,
List
from
contextlib
import
ExitStack
from
typing
import
Dict
,
List
,
Optional
import
pytest
from
transformers
import
AutoTokenizer
...
...
@@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core
import
EngineCore
from
vllm.v1.engine.core_client
import
(
AsyncMPClient
,
EngineCoreClient
,
SyncMPClient
)
from
vllm.v1.executor.abstract
import
Executor
if
not
current_platform
.
is_cuda
():
...
...
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
async
def
loop_until_done_async
(
client
:
EngineCoreClient
,
outputs
:
Dict
):
while
True
:
engine_core_outputs
=
await
client
.
get_output_async
().
outputs
engine_core_outputs
=
(
await
client
.
get_output_async
()
)
.
outputs
if
len
(
engine_core_outputs
)
==
0
:
break
...
...
@@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break
# Dummy utility function to monkey-patch into engine core.
def
echo
(
self
,
msg
:
str
,
err_msg
:
Optional
[
str
]
=
None
)
->
str
:
print
(
f
"echo util function called:
{
msg
}
,
{
err_msg
}
"
)
if
err_msg
is
not
None
:
raise
ValueError
(
err_msg
)
return
msg
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode"
,
[
True
,
False
])
def
test_engine_core_client
(
monkeypatch
,
multiprocessing_mode
:
bool
):
...
...
@@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
compilation_config
=
3
)
# Monkey-patch core engine utility function to test.
m
.
setattr
(
EngineCore
,
"echo"
,
echo
,
raising
=
False
)
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
enforce_eager
=
True
)
vllm_config
=
engine_args
.
create_engine_config
(
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
...
...
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client
.
abort_requests
([
request
.
request_id
])
if
multiprocessing_mode
:
"""Utility method invocation"""
@
fork_new_process_for_each_test
@
pytest
.
mark
.
asyncio
core_client
:
SyncMPClient
=
client
result
=
core_client
.
_call_utility
(
"echo"
,
"testarg"
)
assert
result
==
"testarg"
with
pytest
.
raises
(
Exception
)
as
e_info
:
core_client
.
_call_utility
(
"echo"
,
None
,
"help!"
)
assert
str
(
e_info
.
value
)
==
"Call to echo method failed: help!"
@
pytest
.
mark
.
asyncio
(
loop_scope
=
"function"
)
async
def
test_engine_core_client_asyncio
(
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
# Monkey-patch core engine utility function to test.
m
.
setattr
(
EngineCore
,
"echo"
,
echo
,
raising
=
False
)
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
enforce_eager
=
True
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
...
...
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
executor_class
=
executor_class
,
log_stats
=
True
,
)
after
.
callback
(
client
.
shutdown
)
MAX_TOKENS
=
20
params
=
SamplingParams
(
max_tokens
=
MAX_TOKENS
)
...
...
@@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
else
:
assert
len
(
outputs
[
req_id
])
==
MAX_TOKENS
,
(
f
"
{
len
(
outputs
[
req_id
])
=
}
,
{
MAX_TOKENS
=
}
"
)
"""Utility method invocation"""
core_client
:
AsyncMPClient
=
client
result
=
await
core_client
.
_call_utility_async
(
"echo"
,
"testarg"
)
assert
result
==
"testarg"
with
pytest
.
raises
(
Exception
)
as
e_info
:
await
core_client
.
_call_utility_async
(
"echo"
,
None
,
"help!"
)
assert
str
(
e_info
.
value
)
==
"Call to echo method failed: help!"
vllm/v1/engine/__init__.py
View file @
caf7ff44
...
...
@@ -2,7 +2,7 @@
import
enum
import
time
from
typing
import
List
,
Optional
,
Union
from
typing
import
Any
,
List
,
Optional
,
Union
import
msgspec
...
...
@@ -106,6 +106,18 @@ class EngineCoreOutput(
return
self
.
finish_reason
is
not
None
class
UtilityOutput
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
call_id
:
int
# Non-None implies the call failed, result should be None.
failure_message
:
Optional
[
str
]
=
None
result
:
Any
=
None
class
EngineCoreOutputs
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
...
...
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
# e.g. columnwise layout
# [num_reqs]
outputs
:
List
[
EngineCoreOutput
]
scheduler_stats
:
Optional
[
SchedulerStats
]
outputs
:
List
[
EngineCoreOutput
]
=
[]
scheduler_stats
:
Optional
[
SchedulerStats
]
=
None
timestamp
:
float
=
0.0
utility_output
:
Optional
[
UtilityOutput
]
=
None
def
__post_init__
(
self
):
if
self
.
timestamp
==
0.0
:
self
.
timestamp
=
time
.
monotonic
()
...
...
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
"""
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
PROFILE
=
b
'
\x02
'
RESET_PREFIX_CACHE
=
b
'
\x03
'
ADD_LORA
=
b
'
\x04
'
UTILITY
=
b
'
\x02
'
vllm/v1/engine/core.py
View file @
caf7ff44
...
...
@@ -5,9 +5,11 @@ import signal
import
threading
import
time
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
multiprocessing.connection
import
Connection
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Type
import
msgspec
import
psutil
import
zmq
import
zmq.asyncio
...
...
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_configs
from
vllm.v1.core.scheduler
import
Scheduler
,
SchedulerOutput
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
self
.
add_request
(
request
)
elif
request_type
==
EngineCoreRequestType
.
ABORT
:
self
.
abort_requests
(
request
)
elif
request_type
==
EngineCoreRequestType
.
RESET_PREFIX_CACHE
:
self
.
reset_prefix_cache
()
elif
request_type
==
EngineCoreRequestType
.
PROFILE
:
self
.
model_executor
.
profile
(
request
)
elif
request_type
==
EngineCoreRequestType
.
ADD_LORA
:
self
.
model_executor
.
add_lora
(
request
)
elif
request_type
==
EngineCoreRequestType
.
UTILITY
:
call_id
,
method_name
,
args
=
request
output
=
UtilityOutput
(
call_id
)
try
:
method
=
getattr
(
self
,
method_name
)
output
.
result
=
method
(
*
self
.
_convert_msgspec_args
(
method
,
args
))
except
BaseException
as
e
:
logger
.
exception
(
"Invocation of %s method failed"
,
method_name
)
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
f
" failed:
{
str
(
e
)
}
"
)
self
.
output_queue
.
put_nowait
(
EngineCoreOutputs
(
utility_output
=
output
))
@
staticmethod
def
_convert_msgspec_args
(
method
,
args
):
"""If a provided arg type doesn't match corresponding target method
arg type, try converting to msgspec object."""
if
not
args
:
return
args
arg_types
=
signature
(
method
).
parameters
.
values
()
assert
len
(
args
)
<=
len
(
arg_types
)
return
tuple
(
msgspec
.
convert
(
v
,
type
=
p
.
annotation
)
if
isclass
(
p
.
annotation
)
and
issubclass
(
p
.
annotation
,
msgspec
.
Struct
)
and
not
isinstance
(
v
,
p
.
annotation
)
else
v
for
v
,
p
in
zip
(
args
,
arg_types
))
def
process_input_socket
(
self
,
input_path
:
str
):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder
=
MsgpackDecoder
(
EngineCoreRequest
)
add_lora_decoder
=
MsgpackDecoder
(
LoRARequest
)
generic_decoder
=
MsgpackDecoder
()
with
zmq_socket_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
...
...
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
decoder
=
None
if
request_type
==
EngineCoreRequestType
.
ADD
:
decoder
=
add_request_decoder
elif
request_type
==
EngineCoreRequestType
.
ADD_LORA
:
decoder
=
add_lora_decoder
else
:
decoder
=
generic_decoder
decoder
=
add_request_decoder
if
(
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frame
.
buffer
)
# Push to input queue for core busy loop.
...
...
vllm/v1/engine/core_client.py
View file @
caf7ff44
...
...
@@ -2,10 +2,14 @@
import
asyncio
import
os
import
queue
import
signal
import
uuid
import
weakref
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
List
,
Optional
,
Type
from
concurrent.futures
import
Future
from
threading
import
Thread
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
Union
import
zmq
import
zmq.asyncio
...
...
@@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
from
vllm.utils
import
(
get_open_zmq_ipc_path
,
kill_process_tree
,
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
...
@@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle
logger
=
init_logger
(
__name__
)
AnyFuture
=
Union
[
asyncio
.
Future
[
Any
],
Future
[
Any
]]
class
EngineCoreClient
(
ABC
):
"""
...
...
@@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
"log_stats"
:
log_stats
,
})
self
.
utility_results
:
Dict
[
int
,
AnyFuture
]
=
{}
def
shutdown
(
self
):
"""Clean up background resources."""
if
hasattr
(
self
,
"proc_handle"
):
...
...
@@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
self
.
_finalizer
()
def
_process_utility_output
(
output
:
UtilityOutput
,
utility_results
:
Dict
[
int
,
AnyFuture
]):
"""Set the result from a utility method in the waiting future"""
future
=
utility_results
.
pop
(
output
.
call_id
)
if
output
.
failure_message
is
not
None
:
future
.
set_exception
(
Exception
(
output
.
failure_message
))
else
:
future
.
set_result
(
output
.
result
)
class
SyncMPClient
(
MPClient
):
"""Synchronous client for multi-proc EngineCore."""
...
...
@@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
log_stats
=
log_stats
,
)
def
get_output
(
self
)
->
EngineCoreOutputs
:
self
.
outputs_queue
:
queue
.
Queue
[
EngineCoreOutputs
]
=
queue
.
Queue
()
(
frame
,
)
=
self
.
output_socket
.
recv_multipart
(
copy
=
False
)
return
self
.
decoder
.
decode
(
frame
.
buffer
)
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
output_socket
=
self
.
output_socket
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
def
process_outputs_socket
():
while
True
:
(
frame
,
)
=
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
else
:
outputs_queue
.
put_nowait
(
outputs
)
# Process outputs from engine in separate thread.
Thread
(
target
=
process_outputs_socket
,
daemon
=
True
).
start
()
def
get_output
(
self
)
->
EngineCoreOutputs
:
return
self
.
outputs_queue
.
get
()
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
...
...
@@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
def
_call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
:
Future
[
Any
]
=
Future
()
self
.
utility_results
[
call_id
]
=
future
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
return
future
.
result
()
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
...
...
@@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_
send_input
(
EngineCoreRequestType
.
PROFILE
,
is_start
)
self
.
_
call_utility
(
"profile"
,
is_start
)
def
reset_prefix_cache
(
self
)
->
None
:
self
.
_
send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
None
)
self
.
_
call_utility
(
"reset_prefix_cache"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
self
.
_
send_input
(
EngineCoreRequestType
.
ADD_LORA
,
lora_request
)
self
.
_
call_utility
(
"add_lora"
,
lora_request
)
class
AsyncMPClient
(
MPClient
):
...
...
@@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
log_stats
=
log_stats
,
)
self
.
outputs_queue
:
Optional
[
asyncio
.
Queue
[
byte
s
]]
=
None
self
.
outputs_queue
:
Optional
[
asyncio
.
Queue
[
EngineCoreOutput
s
]]
=
None
self
.
queue_task
:
Optional
[
asyncio
.
Task
]
=
None
async
def
_start_output_queue_task
(
self
):
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
self
.
outputs_queue
=
asyncio
.
Queue
()
output_socket
=
self
.
output_socket
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
async
def
process_outputs_socket
():
while
True
:
(
frame
,
)
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
.
buffer
)
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
else
:
outputs_queue
.
put_nowait
(
outputs
)
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
())
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
if
self
.
outputs_queue
is
None
:
# Perform IO in separate task to parallelize as much as possible
self
.
outputs_queue
=
asyncio
.
Queue
()
async
def
process_outputs_socket
():
assert
self
.
outputs_queue
is
not
None
while
True
:
(
frame
,
)
=
await
self
.
output_socket
.
recv_multipart
(
copy
=
False
)
self
.
outputs_queue
.
put_nowait
(
frame
.
buffer
)
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
())
return
self
.
decoder
.
decode
(
await
self
.
outputs_queue
.
get
())
await
self
.
_start_output_queue_task
()
assert
self
.
outputs_queue
is
not
None
return
await
self
.
outputs_queue
.
get
()
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
)
->
None
:
...
...
@@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
if
self
.
outputs_queue
is
None
:
await
self
.
_start_output_queue_task
()
async
def
_call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
await
self
.
_send_input
(
EngineCoreRequestType
.
UTILITY
,
(
call_id
,
method
,
args
))
return
await
future
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
# NOTE: text prompt is not needed in the core engine as it has been
# tokenized.
...
...
@@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_
send_input
(
EngineCoreRequestType
.
PROFILE
,
is_start
)
await
self
.
_
call_utility_async
(
"profile"
,
is_start
)
async
def
reset_prefix_cache_async
(
self
)
->
None
:
await
self
.
_
send_input
(
EngineCoreRequestType
.
RESET_PREFIX_CACHE
,
None
)
await
self
.
_
call_utility_async
(
"reset_prefix_cache"
)
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
None
:
await
self
.
_
send_input
(
EngineCoreRequestType
.
ADD_LORA
,
lora_request
)
await
self
.
_
call_utility_async
(
"add_lora"
,
lora_request
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment