Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd143ef5
Unverified
Commit
dd143ef5
authored
Apr 10, 2025
by
Nick Hill
Committed by
GitHub
Apr 10, 2025
Browse files
[V1] Zero-copy tensor/ndarray serialization/transmission (#13790)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
daefed05
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
217 additions
and
58 deletions
+217
-58
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+80
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-4
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+13
-13
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+120
-41
No files found.
tests/v1/test_serial_utils.py
0 → 100644
View file @
dd143ef5
# SPDX-License-Identifier: Apache-2.0
from
collections
import
UserDict
from
dataclasses
import
dataclass
import
numpy
as
np
import
torch
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
class UnrecognizedType(UserDict):
    """Dict-like container that also carries one integer attribute.

    The round-trip test uses it as a stand-in for a value that is neither
    a tensor nor a numpy array.
    """

    def __init__(self, an_int: int):
        # Let UserDict set up its (empty) backing mapping first.
        super().__init__()
        self.an_int = an_int
@dataclass
class MyType:
    """Composite payload used by the encode/decode round-trip test."""

    # Single tensor field; compared with `torch.equal` after decoding.
    tensor1: torch.Tensor
    # Plain string field; compared with `==`.
    a_string: str
    # Several tensors in a list; compared element-wise with `torch.equal`.
    list_of_tensors: list[torch.Tensor]
    # NumPy array field; compared with `np.array_equal`.
    numpy_array: np.ndarray
    # Instance of the custom UserDict subclass; only `an_int` is compared.
    unrecognized: UnrecognizedType
def test_encode_decode():
    """Round-trip a MyType through MsgpackEncoder / MsgpackDecoder."""
    int_tensor = torch.randint(low=0,
                               high=100,
                               size=(1024, ),
                               dtype=torch.int32)
    tensors = [
        torch.rand((1, 10), dtype=torch.float32),
        torch.rand((3, 5, 4000), dtype=torch.float64),
        torch.tensor(1984),  # 0-dim (scalar) tensor as well
    ]
    original = MyType(
        tensor1=int_tensor,
        a_string="hello",
        list_of_tensors=tensors,
        numpy_array=np.arange(512),
        unrecognized=UnrecognizedType(33),
    )

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyType)

    # --- encode(): the encoder allocates the main buffer itself ---
    frames = encoder.encode(original)

    # Expected frames: the main msgpack buffer, plus one buffer for each
    # of the two large tensors, plus one for the large numpy array.
    # "Large" means >= 256 bytes; the (1, 10) float32 tensor and the
    # scalar tensor are below that and are encoded inline.
    assert len(frames) == 4

    round_tripped: MyType = decoder.decode(frames)
    assert_equal(round_tripped, original)

    # --- encode_into(): the caller supplies the main buffer ---
    scratch = bytearray()
    frames_into = encoder.encode_into(original, scratch)

    assert len(frames_into) == 4
    # The caller's buffer itself must be returned as the first frame.
    assert frames_into[0] is scratch

    round_tripped_into: MyType = decoder.decode(frames_into)
    assert_equal(round_tripped_into, original)
def assert_equal(obj1: "MyType", obj2: "MyType"):
    """Assert that two MyType-like objects hold equal field values.

    Args:
        obj1: First object to compare.
        obj2: Second object to compare.

    Raises:
        AssertionError: If any compared field differs, including when the
            two `list_of_tensors` have different lengths.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # zip() stops at the shorter input, so without this length check a
    # list that lost trailing tensors would still compare as equal.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b)
        for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    # Only the extra attribute is compared, not the dict contents.
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
vllm/v1/engine/core.py
View file @
dd143ef5
...
@@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore):
...
@@ -490,14 +490,14 @@ class EngineCoreProc(EngineCore):
while
True
:
while
True
:
# (RequestType, RequestData)
# (RequestType, RequestData)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
type_frame
,
*
data_frame
s
=
socket
.
recv_multipart
(
copy
=
False
)
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
# Deserialize the request data.
decoder
=
add_request_decoder
if
(
decoder
=
add_request_decoder
if
(
request_type
request_type
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
==
EngineCoreRequestType
.
ADD
)
else
generic_decoder
request
=
decoder
.
decode
(
data_frame
.
buffer
)
request
=
decoder
.
decode
(
data_frame
s
)
# Push to input queue for core busy loop.
# Push to input queue for core busy loop.
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
self
.
input_queue
.
put_nowait
((
request_type
,
request
))
...
@@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore):
...
@@ -514,8 +514,8 @@ class EngineCoreProc(EngineCore):
while
True
:
while
True
:
outputs
=
self
.
output_queue
.
get
()
outputs
=
self
.
output_queue
.
get
()
outputs
.
engine_index
=
engine_index
outputs
.
engine_index
=
engine_index
encoder
.
encode_into
(
outputs
,
buffer
)
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send
(
buffer
,
copy
=
False
)
socket
.
send
_multipart
(
buffer
s
,
copy
=
False
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
ENGINE_PAUSED_OUTPUTS
=
EngineCoreOutputs
(
engine_paused
=
True
)
...
...
vllm/v1/engine/core_client.py
View file @
dd143ef5
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
...
@@ -26,7 +26,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
,
bytestr
from
vllm.v1.utils
import
BackgroundProcHandle
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -505,8 +505,8 @@ class SyncMPClient(MPClient):
...
@@ -505,8 +505,8 @@ class SyncMPClient(MPClient):
# shutdown signal, exit thread.
# shutdown signal, exit thread.
break
break
frame
=
out_socket
.
recv
(
copy
=
False
)
frame
s
=
out_socket
.
recv
_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
...
@@ -529,7 +529,7 @@ class SyncMPClient(MPClient):
...
@@ -529,7 +529,7 @@ class SyncMPClient(MPClient):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
# (Identity, RequestType, SerializedRequest)
# (Identity, RequestType, SerializedRequest)
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
*
self
.
encoder
.
encode
(
request
))
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
...
@@ -633,8 +633,8 @@ class AsyncMPClient(MPClient):
...
@@ -633,8 +633,8 @@ class AsyncMPClient(MPClient):
async
def
process_outputs_socket
():
async
def
process_outputs_socket
():
while
True
:
while
True
:
(
frame
,
)
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
frame
s
=
await
output_socket
.
recv_multipart
(
copy
=
False
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
frame
s
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
...
@@ -666,12 +666,12 @@ class AsyncMPClient(MPClient):
...
@@ -666,12 +666,12 @@ class AsyncMPClient(MPClient):
if
engine
is
None
:
if
engine
is
None
:
engine
=
self
.
core_engine
engine
=
self
.
core_engine
message
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
return
self
.
_send_input_message
(
message
,
engine
)
return
self
.
_send_input_message
(
message
,
engine
)
def
_send_input_message
(
self
,
message
:
tuple
[
bytes
,
bytes
],
def
_send_input_message
(
self
,
message
:
tuple
[
bytes
tr
,
...
],
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
message
=
(
engine
.
identity
,
)
+
message
# type: ignore[assignment]
message
=
(
engine
.
identity
,
)
+
message
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
...
@@ -684,8 +684,8 @@ class AsyncMPClient(MPClient):
...
@@ -684,8 +684,8 @@ class AsyncMPClient(MPClient):
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
future
=
asyncio
.
get_running_loop
().
create_future
()
future
=
asyncio
.
get_running_loop
().
create_future
()
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
(
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engine
)
await
self
.
_send_input_message
(
message
,
engine
)
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
return
await
future
return
await
future
...
@@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -760,7 +760,7 @@ class DPAsyncMPClient(AsyncMPClient):
# Control message used for triggering dp idle mode loop.
# Control message used for triggering dp idle mode loop.
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
start_dp_msg
=
(
EngineCoreRequestType
.
START_DP
.
value
,
self
.
encoder
.
encode
(
None
))
*
self
.
encoder
.
encode
(
None
))
self
.
num_engines_running
=
0
self
.
num_engines_running
=
0
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
self
.
reqs_in_flight
:
dict
[
str
,
CoreEngine
]
=
{}
...
@@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
...
@@ -794,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient):
# tokenized.
# tokenized.
request
.
prompt
=
None
request
.
prompt
=
None
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
EngineCoreRequestType
.
ADD
.
value
,
*
self
.
encoder
.
encode
(
request
))
chosen_engine
=
self
.
get_core_engine_for_request
()
chosen_engine
=
self
.
get_core_engine_for_request
()
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
self
.
reqs_in_flight
[
request
.
request_id
]
=
chosen_engine
...
...
vllm/v1/serial_utils.py
View file @
dd143ef5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pickle
import
pickle
from
collections.abc
import
Sequence
from
inspect
import
isclass
from
types
import
FunctionType
from
types
import
FunctionType
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
import
numpy
as
np
import
torch
import
torch
import
zmq
from
msgspec
import
msgpack
from
msgspec
import
msgpack
CUSTOM_TYPE_TENSOR
=
1
CUSTOM_TYPE_PICKLE
=
1
CUSTOM_TYPE_PICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
2
CUSTOM_TYPE_CLOUDPICKLE
=
3
# TODO calibrate this size
INLINE_BUF_SIZE_THRESHOLD
=
256
class
MsgpackEncoder
:
bytestr
=
Union
[
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
]
"""Encoder with custom torch tensor serialization."""
def
__init__
(
self
):
self
.
encoder
=
msgpack
.
Encoder
(
enc_hook
=
custom_enc_hook
)
def
encode
(
self
,
obj
:
Any
)
->
bytes
:
class
MsgpackEncoder
:
return
self
.
encoder
.
encode
(
obj
)
"""Encoder with custom torch tensor and numpy array serialization.
def
encode_into
(
self
,
obj
:
Any
,
buf
:
bytearray
)
->
None
:
Note that unlike vanilla `msgspec` Encoders, this interface is generally
self
.
encoder
.
encode_into
(
obj
,
buf
)
not thread-safe when encoding tensors / numpy arrays.
"""
def __init__(self):
    """Create the msgspec encoder and the per-call buffer stash."""
    # Stash of out-of-band buffers for the encode call in progress, or
    # None when no call is running. `enc_hook` reaches it through `self`
    # because there is no other way to pass custom data to the hook.
    self.aux_buffers: Optional[list[bytestr]] = None
    self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
def encode(self, obj: Any) -> Sequence[bytestr]:
    """Encode `obj` and return the resulting list of buffers.

    Element 0 is the top-level msgpack buffer. Any further elements are
    the backing buffers of tensors / numpy arrays collected by the hook;
    they are returned alongside the main buffer instead of having their
    data copied into it.
    """
    # Slot 0 is a placeholder until the main buffer exists; the hook
    # appends backing buffers after it while encoding runs.
    frames: list[bytestr] = [b'']
    self.aux_buffers = frames
    try:
        frames[0] = self.encoder.encode(obj)
        return frames
    finally:
        # Always clear the stash, even if encoding raised.
        self.aux_buffers = None
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    """Encode `obj` into `buf` and return the resulting list of buffers.

    Element 0 of the returned list is `buf` itself; any further elements
    are backing buffers appended by the hook during encoding.
    """
    frames: list[bytestr] = [buf]
    self.aux_buffers = frames
    try:
        self.encoder.encode_into(obj, buf)
        return frames
    finally:
        # Always clear the stash, even if encoding raised.
        self.aux_buffers = None
def enc_hook(self, obj: Any) -> Any:
    """Convert an object msgspec cannot encode natively.

    Tensors and numpy arrays go through `_encode_ndarray`; plain
    functions are cloudpickled; everything else is pickled.
    """
    if isinstance(obj, torch.Tensor):
        array = obj.numpy()
        return self._encode_ndarray(array)

    # ndarrays of object ('O') or void ('V') kind are excluded here and
    # fall through to the pickle path at the bottom.
    is_plain_ndarray = (isinstance(obj, np.ndarray)
                        and obj.dtype.kind not in ('O', 'V'))
    if is_plain_ndarray:
        return self._encode_ndarray(obj)

    if isinstance(obj, FunctionType):
        # `pickle` is generally faster than cloudpickle, but can have
        # problems serializing methods.
        payload = cloudpickle.dumps(obj)
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, payload)

    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, payload)
def _encode_ndarray(
    self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, bytes, memoryview]]:
    """Serialize an ndarray as a `(dtype, shape, data)` tuple.

    `data` is the array's bytes when the array is a scalar or smaller
    than `INLINE_BUF_SIZE_THRESHOLD`; otherwise it is the index in
    `self.aux_buffers` at which the array's backing buffer was stashed.

    Must only be called while an encode call is in progress, i.e. while
    `self.aux_buffers` is set.
    """
    assert self.aux_buffers is not None
    if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
        # Encode small arrays and scalars inline. Only dtype and shape
        # travel with the data, so the bytes must be in C order; copy
        # when the array is strided or Fortran-ordered.
        data = obj.data if obj.flags.c_contiguous else obj.tobytes()
    else:
        # Otherwise encode index of backing buffer.
        obj = np.ascontiguousarray(obj)
        data = len(self.aux_buffers)
        self.aux_buffers.append(obj.data)

    # We serialize the ndarray as a tuple of native types.
    # The data is either inlined if small, or an index into a list of
    # backing buffers that we've stashed in `aux_buffers`.
    return obj.dtype.str, obj.shape, data
class
MsgpackDecoder
:
class
MsgpackDecoder
:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization.
Note that unlike vanilla `msgspec` Decoders, this interface is generally
not thread-safe when decoding tensors / numpy arrays.
"""
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
def
__init__
(
self
,
t
:
Optional
[
Any
]
=
None
):
args
=
()
if
t
is
None
else
(
t
,
)
args
=
()
if
t
is
None
else
(
t
,
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
custom_ext_hook
)
self
.
decoder
=
msgpack
.
Decoder
(
*
args
,
ext_hook
=
self
.
ext_hook
,
def
decode
(
self
,
obj
:
Any
):
dec_hook
=
self
.
dec_hook
)
return
self
.
decoder
.
decode
(
obj
)
self
.
aux_buffers
:
Sequence
[
bytestr
]
=
()
def
decode
(
self
,
bufs
:
Union
[
bytestr
,
Sequence
[
bytestr
]])
->
Any
:
def
custom_enc_hook
(
obj
:
Any
)
->
Any
:
if
isinstance
(
bufs
,
(
bytes
,
bytearray
,
memoryview
,
zmq
.
Frame
)):
if
isinstance
(
obj
,
torch
.
Tensor
):
# TODO - This check can become `isinstance(bufs, bytestr)`
# NOTE(rob): it is fastest to use numpy + pickle
# as of Python 3.10.
# when serializing torch tensors.
return
self
.
decoder
.
decode
(
bufs
)
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return
msgpack
.
Ext
(
CUSTOM_TYPE_TENSOR
,
pickle
.
dumps
(
obj
.
numpy
()))
self
.
aux_buffers
=
bufs
try
:
if
isinstance
(
obj
,
FunctionType
):
return
self
.
decoder
.
decode
(
bufs
[
0
])
return
msgpack
.
Ext
(
CUSTOM_TYPE_CLOUDPICKLE
,
cloudpickle
.
dumps
(
obj
))
finally
:
self
.
aux_buffers
=
()
return
msgpack
.
Ext
(
CUSTOM_TYPE_PICKLE
,
pickle
.
dumps
(
obj
))
def
dec_hook
(
self
,
t
:
type
,
obj
:
Any
)
->
Any
:
# Given native types in `obj`, convert to type `t`.
def
custom_ext_hook
(
code
:
int
,
data
:
memoryview
)
->
Any
:
if
isclass
(
t
):
if
code
==
CUSTOM_TYPE_TENSOR
:
if
issubclass
(
t
,
np
.
ndarray
):
return
torch
.
from_numpy
(
pickle
.
loads
(
data
))
return
self
.
_decode_ndarray
(
obj
)
if
code
==
CUSTOM_TYPE_PICKLE
:
if
issubclass
(
t
,
torch
.
Tensor
):
return
pickle
.
loads
(
data
)
return
torch
.
from_numpy
(
self
.
_decode_ndarray
(
obj
))
if
code
==
CUSTOM_TYPE_CLOUDPICKLE
:
return
obj
return
cloudpickle
.
loads
(
data
)
def
_decode_ndarray
(
self
,
arr
:
Any
)
->
np
.
ndarray
:
raise
NotImplementedError
(
f
"Extension type code
{
code
}
is not supported"
)
dtype
,
shape
,
data
=
arr
buffer
=
self
.
aux_buffers
[
data
]
if
isinstance
(
data
,
int
)
else
data
return
np
.
ndarray
(
buffer
=
buffer
,
dtype
=
np
.
dtype
(
dtype
),
shape
=
shape
)
def ext_hook(self, code: int, data: memoryview) -> Any:
    """Decode a msgpack extension payload according to its type code.

    Raises:
        NotImplementedError: If `code` is not one of the known codes.
    """
    # NOTE: both loaders unpickle `data`, so the payload must come from
    # a trusted peer.
    if code == CUSTOM_TYPE_PICKLE:
        loader = pickle.loads
    elif code == CUSTOM_TYPE_CLOUDPICKLE:
        loader = cloudpickle.loads
    else:
        raise NotImplementedError(
            f"Extension type code {code} is not supported")
    return loader(data)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment