Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56bd537d
Unverified
Commit
56bd537d
authored
Jul 30, 2025
by
Nick Hill
Committed by
GitHub
Jul 30, 2025
Browse files
[Misc] Support more collective_rpc return types (#21845)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent
8f0d5167
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
121 additions
and
6 deletions
+121
-6
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+64
-1
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+8
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-3
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+2
-1
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+44
-0
No files found.
tests/v1/engine/test_engine_core_client.py
View file @
56bd537d
...
@@ -6,8 +6,9 @@ import os
...
@@ -6,8 +6,9 @@ import os
import
signal
import
signal
import
time
import
time
import
uuid
import
uuid
from
dataclasses
import
dataclass
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Optional
from
typing
import
Optional
,
Union
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
pytest
import
pytest
...
@@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
...
@@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
client
.
shutdown
()
client
.
shutdown
()
@dataclass
class MyDataclass:
    """Minimal custom (non-native) return type used by ``echo_dc``."""
    message: str


# Dummy utility function to monkey-patch into engine core.
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    Args:
        self: Receiver the function is bound to once patched in as a method;
            it is not used by the body.
        msg: Text to wrap.
        return_list: When true, return a list of three ``MyDataclass``
            instances instead of a single one.

    Returns:
        A single ``MyDataclass(msg)``, or a list of three of them.
    """
    print(f"echo dc util function called: {msg}")
    # Return dataclass to verify support for returning custom types
    # (for which there is special handling to make it work with msgspec).
    if return_list:
        return [MyDataclass(msg) for _ in range(3)]
    return MyDataclass(msg)
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check that a utility call can return custom (dataclass) types.

    ``echo_dc`` is patched onto ``EngineCore`` and invoked through the async
    multiprocess client, once for a single dataclass result and once for a
    list of dataclasses.
    """

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(result, MyDataclass)
            assert result.message == "testarg2"

            result = await core_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(result, list)
            # echo_dc returns three items; without this check the all()
            # below would pass vacuously on an empty list.
            assert len(result) == 3
            assert all(
                isinstance(r, MyDataclass) and r.message == "testarg2"
                for r in result)
        finally:
            client.shutdown()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode,publisher_config"
,
"multiprocessing_mode,publisher_config"
,
[(
True
,
"tcp"
),
(
False
,
"inproc"
)],
[(
True
,
"tcp"
),
(
False
,
"inproc"
)],
...
...
vllm/v1/engine/__init__.py
View file @
56bd537d
...
@@ -123,6 +123,13 @@ class EngineCoreOutput(
...
@@ -123,6 +123,13 @@ class EngineCoreOutput(
return
self
.
finish_reason
is
not
None
return
self
.
finish_reason
is
not
None
class UtilityResult:
    """Wrapper for special handling when serializing/deserializing."""

    def __init__(self, r: Any = None) -> None:
        # The wrapped utility return value; None when nothing was returned.
        self.result = r

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.result!r})"
class
UtilityOutput
(
class
UtilityOutput
(
msgspec
.
Struct
,
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
...
@@ -132,7 +139,7 @@ class UtilityOutput(
...
@@ -132,7 +139,7 @@ class UtilityOutput(
# Non-None implies the call failed, result should be None.
# Non-None implies the call failed, result should be None.
failure_message
:
Optional
[
str
]
=
None
failure_message
:
Optional
[
str
]
=
None
result
:
Any
=
None
result
:
Optional
[
UtilityResult
]
=
None
class
EngineCoreOutputs
(
class
EngineCoreOutputs
(
...
...
vllm/v1/engine/core.py
View file @
56bd537d
...
@@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
...
@@ -36,7 +36,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestType
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
UtilityOutput
)
UtilityOutput
,
UtilityResult
)
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.engine.mm_input_cache
import
MirroredProcessingCache
from
vllm.v1.engine.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.v1.engine.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
...
@@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
...
@@ -715,8 +715,8 @@ class EngineCoreProc(EngineCore):
output
=
UtilityOutput
(
call_id
)
output
=
UtilityOutput
(
call_id
)
try
:
try
:
method
=
getattr
(
self
,
method_name
)
method
=
getattr
(
self
,
method_name
)
output
.
result
=
method
(
result
=
method
(
*
self
.
_convert_msgspec_args
(
method
,
args
))
*
self
.
_convert_msgspec_args
(
method
,
args
)
)
output
.
result
=
UtilityResult
(
result
)
except
BaseException
as
e
:
except
BaseException
as
e
:
logger
.
exception
(
"Invocation of %s method failed"
,
method_name
)
logger
.
exception
(
"Invocation of %s method failed"
,
method_name
)
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
output
.
failure_message
=
(
f
"Call to
{
method_name
}
method"
...
...
vllm/v1/engine/core_client.py
View file @
56bd537d
...
@@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
...
@@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
if
output
.
failure_message
is
not
None
:
if
output
.
failure_message
is
not
None
:
future
.
set_exception
(
Exception
(
output
.
failure_message
))
future
.
set_exception
(
Exception
(
output
.
failure_message
))
else
:
else
:
future
.
set_result
(
output
.
result
)
assert
output
.
result
is
not
None
future
.
set_result
(
output
.
result
.
result
)
class
SyncMPClient
(
MPClient
):
class
SyncMPClient
(
MPClient
):
...
...
vllm/v1/serial_utils.py
View file @
56bd537d
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
dataclasses
import
importlib
import
pickle
import
pickle
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
inspect
import
isclass
from
inspect
import
isclass
...
@@ -9,6 +10,7 @@ from types import FunctionType
...
@@ -9,6 +10,7 @@ from types import FunctionType
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
import
msgspec
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
zmq
import
zmq
...
@@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
...
@@ -22,6 +24,7 @@ from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalFlatField
,
MultiModalKwargs
,
MultiModalFlatField
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalKwargsItem
,
MultiModalSharedField
,
NestedTensors
)
MultiModalSharedField
,
NestedTensors
)
from
vllm.v1.engine
import
UtilityResult
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
...
@@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
)
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
)
def
_typestr
(
t
:
type
):
return
t
.
__module__
,
t
.
__qualname__
class
MsgpackEncoder
:
class
MsgpackEncoder
:
"""Encoder with custom torch tensor and numpy array serialization.
"""Encoder with custom torch tensor and numpy array serialization.
...
@@ -122,6 +129,18 @@ class MsgpackEncoder:
...
@@ -122,6 +129,18 @@ class MsgpackEncoder:
for
itemlist
in
mm
.
_items_by_modality
.
values
()
for
itemlist
in
mm
.
_items_by_modality
.
values
()
for
item
in
itemlist
]
for
item
in
itemlist
]
if
isinstance
(
obj
,
UtilityResult
):
result
=
obj
.
result
if
not
envs
.
VLLM_ALLOW_INSECURE_SERIALIZATION
or
result
is
None
:
return
None
,
result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
cls
=
result
.
__class__
return
_typestr
(
cls
)
if
cls
is
not
list
else
[
_typestr
(
type
(
v
))
for
v
in
result
],
result
if
not
envs
.
VLLM_ALLOW_INSECURE_SERIALIZATION
:
if
not
envs
.
VLLM_ALLOW_INSECURE_SERIALIZATION
:
raise
TypeError
(
f
"Object of type
{
type
(
obj
)
}
is not serializable"
raise
TypeError
(
f
"Object of type
{
type
(
obj
)
}
is not serializable"
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
...
@@ -237,8 +256,33 @@ class MsgpackDecoder:
...
@@ -237,8 +256,33 @@ class MsgpackDecoder:
k
:
self
.
_decode_nested_tensors
(
v
)
k
:
self
.
_decode_nested_tensors
(
v
)
for
k
,
v
in
obj
.
items
()
for
k
,
v
in
obj
.
items
()
})
})
if
t
is
UtilityResult
:
return
self
.
_decode_utility_result
(
obj
)
return
obj
return
obj
def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its decoded wire form.

    Args:
        obj: A two-item ``(result_type, result)`` pair. ``result_type`` is
            ``None`` when no type information was sent, a
            ``[module, qualname]`` pair for a single value, or a list of
            such pairs (one per element) when ``result`` is a list.

    Returns:
        A ``UtilityResult`` wrapping the converted payload.

    Raises:
        TypeError: If type information is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set, or if the
            payload does not have the expected shape.
    """
    result_type, result = obj
    if result_type is None:
        # No type information was sent; pass the payload through unchanged.
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # Validate explicitly instead of with assert: this is data read off the
    # wire, and asserts are stripped when Python runs with -O.
    if not isinstance(result_type, list):
        raise TypeError("Utility result type info must be a list, got "
                        f"{type(result_type)}")

    if len(result_type) == 2 and isinstance(result_type[0], str):
        # A single [module, qualname] pair describing the whole result.
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise there is one type pair per element of a list result.
    if not isinstance(result, list):
        raise TypeError("Utility result must be a list when per-element "
                        f"types are given, got {type(result)}")
    if len(result) != len(result_type):
        # zip() would silently drop the unmatched tail.
        raise TypeError(
            f"Utility result has {len(result)} elements but "
            f"{len(result_type)} element types were given")
    return UtilityResult([
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ])
def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded payload into the type named by ``result_type``.

    Args:
        result_type: A ``(module name, qualified name)`` pair naming the
            target type.
        result: The decoded payload to convert.

    Returns:
        ``result`` converted to the resolved type via ``msgspec.convert``,
        using this decoder's ``dec_hook``.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the qualified name cannot be resolved in it.
    """
    mod_name, name = result_type
    # NOTE(review): this imports a module named in the incoming message; the
    # visible caller only reaches here when
    # VLLM_ALLOW_INSECURE_SERIALIZATION is set.
    target: Any = importlib.import_module(mod_name)
    # The name is a __qualname__, which is dotted for nested classes
    # ("Outer.Inner"), so walk it one attribute at a time.
    for part in name.split("."):
        target = getattr(target, part)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
dtype
,
shape
,
data
=
arr
dtype
,
shape
,
data
=
arr
# zero-copy decode. We assume the ndarray will not be kept around,
# zero-copy decode. We assume the ndarray will not be kept around,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment