Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56bd537d
Unverified
Commit
56bd537d
authored
Jul 30, 2025
by
Nick Hill
Committed by
GitHub
Jul 30, 2025
Browse files
[Misc] Support more collective_rpc return types (#21845)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
8f0d5167
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
121 additions
and
6 deletions
+121
-6
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+64
-1
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+8
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-3
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+2
-1
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+44
-0
No files found.
tests/v1/engine/test_engine_core_client.py
View file @
56bd537d
...
...
@@ -6,8 +6,9 @@ import os
import
signal
import
time
import
uuid
from
dataclasses
import
dataclass
from
threading
import
Thread
from
typing
import
Optional
from
typing
import
Optional
,
Union
from
unittest.mock
import
MagicMock
import
pytest
...
...
@@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
client
.
shutdown
()
@dataclass
class MyDataclass:
    """Simple dataclass carrying a single string message."""

    # Text payload held by the instance.
    message: str
# Dummy utility function to monkey-patch into engine core.
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    Args:
        msg: Text to place in the returned dataclass instance(s).
        return_list: When true, return a list of three ``MyDataclass``
            instances instead of a single one.

    Returns:
        A single ``MyDataclass`` or a three-element list of them, each
        carrying ``msg``.
    """
    print(f"echo dc util function called: {msg}")
    # Return dataclass to verify support for returning custom types
    # (for which there is special handling to make it work with msgspec).
    if not return_list:
        return MyDataclass(msg)
    return [MyDataclass(msg) for _ in range(3)]
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check that utility calls can return custom (non-native) types.

    A dummy ``echo_dc`` method is patched onto ``EngineCore`` and invoked
    through ``call_utility_async`` on a multiprocess asyncio client, once
    returning a single ``MyDataclass`` and once returning a list of them.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(result, MyDataclass)
            assert result.message == "testarg2"

            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(result, list)
            # echo_dc returns exactly three items in list mode; checking the
            # length keeps the all() below from passing vacuously on an
            # empty list.
            assert len(result) == 3
            assert all(
                isinstance(r, MyDataclass) and r.message == "testarg2"
                for r in result)
        finally:
            client.shutdown()
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode,publisher_config"
,
[(
True
,
"tcp"
),
(
False
,
"inproc"
)],
...
...
vllm/v1/engine/__init__.py
View file @
56bd537d
...
...
@@ -123,6 +123,13 @@ class EngineCoreOutput(
return
self
.
finish_reason
is
not
None
class UtilityResult:
    """Wrapper for special handling when serializing/deserializing.

    Holds an arbitrary value in ``result`` so that it can be recognised by
    type and treated separately from other values.
    """

    def __init__(self, r: Any = None):
        # The wrapped value, stored unchanged (defaults to None).
        self.result = r
class
UtilityOutput
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
...
...
@@ -132,7 +139,7 @@ class UtilityOutput(
# Non-None implies the call failed, result should be None.
failure_message
:
Optional
[
str
]
=
None
result
:
Any
=
None
result
:
Optional
[
UtilityResult
]
=
None
class
EngineCoreOutputs
(
...
...
vllm/v1/engine/core.py
View file @
56bd537d
...
...
@@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
UtilityOutput
)
UtilityOutput
,
UtilityResult
)
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.engine.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.v1.executor.abstract
import
Executor
...
...
@@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
output
=
UtilityOutput
(
call_id
)
try
:
method
=
getattr
(
self
,
method_name
)
output
.
result
=
method
(
*
self
.
_convert_msgspec_args
(
method
,
args
)
)
result
=
method
(
*
self
.
_convert_msgspec_args
(
method
,
args
))
output
.
result
=
UtilityResult
(
result
)
except
BaseException
as
e
:
logger
.
exception
(
"Invocation of %s method failed"
,
method_name
)
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
...
...
vllm/v1/engine/core_client.py
View file @
56bd537d
...
...
@@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
if
output
.
failure_message
is
not
None
:
future
.
set_exception
(
Exception
(
output
.
failure_message
))
else
:
future
.
set_result
(
output
.
result
)
assert
output
.
result
is
not
None
future
.
set_result
(
output
.
result
.
result
)
class
SyncMPClient
(
MPClient
):
...
...
vllm/v1/serial_utils.py
View file @
56bd537d
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
importlib
import
pickle
from
collections.abc
import
Sequence
from
inspect
import
isclass
...
...
@@ -9,6 +10,7 @@ from types import FunctionType
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
msgspec
import
numpy
as
np
import
torch
import
zmq
...
...
@@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalFlatField
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
NestedTensors
)
from
vllm.v1.engine
import
UtilityResult
logger
=
init_logger
(
__name__
)
...
...
@@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
)
def
_typestr
(
t
:
type
):
return
t
.
__module__
,
t
.
__qualname__
class
MsgpackEncoder
:
"""Encoder with custom torch tensor and numpy array serialization.
...
...
@@ -122,6 +129,18 @@ class MsgpackEncoder:
for
itemlist
in
mm
.
_items_by_modality
.
values
()
for
item
in
itemlist
]
if
isinstance
(
obj
,
UtilityResult
):
result
=
obj
.
result
if
not
envs
.
VLLM_ALLOW_INSECURE_SERIALIZATION
or
result
is
None
:
return
None
,
result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
cls
=
result
.
__class__
return
_typestr
(
cls
)
if
cls
is
not
list
else
[
_typestr
(
type
(
v
))
for
v
in
result
],
result
if
not
envs
.
VLLM_ALLOW_INSECURE_SERIALIZATION
:
raise
TypeError
(
f
"Object of type
{
type
(
obj
)
}
is not serializable"
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
...
...
@@ -237,8 +256,33 @@ class MsgpackDecoder:
k
:
self
.
_decode_nested_tensors
(
v
)
for
k
,
v
in
obj
.
items
()
})
if
t
is
UtilityResult
:
return
self
.
_decode_utility_result
(
obj
)
return
obj
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its decoded ``(type info, value)`` pair.

    Args:
        obj: Two-element sequence ``(result_type, result)``. ``result_type``
            is ``None`` when no conversion is needed, a ``[module, name]``
            pair describing the type of ``result``, or a list of such pairs
            (one per element) when ``result`` is a list.

    Returns:
        A ``UtilityResult`` wrapping the (possibly converted) value.

    Raises:
        TypeError: If type info is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not enabled, or if the
            type info / value do not have the expected list shape.
        ValueError: If per-element type info and the result list differ in
            length.
    """
    result_type, result = obj
    if result_type is None:
        # No type info was sent: pass the value through untouched.
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # This is validation of incoming data, so raise explicitly rather than
    # relying on assert (which is stripped when running with -O).
    if not isinstance(result_type, list):
        raise TypeError("Utility result type info must be a list, got "
                        f"{type(result_type)}")

    if len(result_type) == 2 and isinstance(result_type[0], str):
        # A single [module, name] pair: the result is one object.
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise the type info is a list of pairs, one per result element.
    if not isinstance(result, list):
        raise TypeError("Utility result must be a list when a list of "
                        f"types is provided, got {type(result)}")
    if len(result_type) != len(result):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"Utility result has {len(result)} elements but "
            f"{len(result_type)} element types were provided")
    return UtilityResult([
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ])
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded ``result`` into the type named by ``result_type``.

    Args:
        result_type: ``(module name, qualified class name)`` pair. The
            class name may be dotted (e.g. ``Outer.Inner`` for a nested
            class).
        result: The decoded value to convert.

    Returns:
        The value produced by ``msgspec.convert`` for the resolved type,
        using this decoder's ``dec_hook`` for custom types.

    Raises:
        ImportError: If the named module cannot be imported.
        AttributeError: If the qualified name cannot be resolved within
            the module.
    """
    mod_name, qualname = result_type
    # NOTE(review): this imports a module named by the incoming data, so it
    # must only be reachable when insecure serialization has been explicitly
    # enabled - confirm every caller enforces that.
    target: Any = importlib.import_module(mod_name)
    # Walk the dotted path one attribute at a time; a single getattr() on
    # the module cannot resolve nested-class qualified names.
    for attr in qualname.split("."):
        target = getattr(target, attr)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
dtype
,
shape
,
data
=
arr
# zero-copy decode. We assume the ndarray will not be kept around,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment