Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b07bf83c
Unverified
Commit
b07bf83c
authored
Apr 25, 2025
by
Nick Hill
Committed by
GitHub
Apr 26, 2025
Browse files
[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
53e8cf53
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
77 additions
and
12 deletions
+77
-12
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+3
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+22
-3
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+52
-9
No files found.
tests/v1/test_serial_utils.py
View file @
b07bf83c
...
@@ -32,6 +32,7 @@ class MyType:
...
@@ -32,6 +32,7 @@ class MyType:
large_f_contig_tensor
:
torch
.
Tensor
large_f_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
empty_tensor
:
torch
.
Tensor
def
test_encode_decode
():
def
test_encode_decode
():
...
@@ -58,6 +59,7 @@ def test_encode_decode():
...
@@ -58,6 +59,7 @@ def test_encode_decode():
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
empty_tensor
=
torch
.
empty
(
0
),
)
)
encoder
=
MsgpackEncoder
(
size_threshold
=
256
)
encoder
=
MsgpackEncoder
(
size_threshold
=
256
)
...
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
...
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2
.
small_non_contig_tensor
)
obj2
.
small_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
obj2
.
large_non_contig_tensor
)
obj2
.
large_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
empty_tensor
,
obj2
.
empty_tensor
)
vllm/v1/engine/core.py
View file @
b07bf83c
...
@@ -5,6 +5,7 @@ import signal
...
@@ -5,6 +5,7 @@ import signal
import
sys
import
sys
import
threading
import
threading
import
time
import
time
from
collections
import
deque
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
inspect
import
isclass
,
signature
from
inspect
import
isclass
,
signature
from
logging
import
DEBUG
from
logging
import
DEBUG
...
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
...
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
# Msgpack serialization encoding.
# Msgpack serialization encoding.
encoder
=
MsgpackEncoder
()
encoder
=
MsgpackEncoder
()
# Reuse send buffer.
# Send buffers to reuse.
buffer
=
bytearray
()
reuse_buffers
:
list
[
bytearray
]
=
[]
# Keep references to outputs and buffers until zmq is finished
# with them (outputs may contain tensors/np arrays whose
# backing buffers were extracted for zero-copy send).
pending
=
deque
[
tuple
[
zmq
.
MessageTracker
,
Any
,
bytearray
]]()
# We must set linger to ensure the ENGINE_CORE_DEAD
# We must set linger to ensure the ENGINE_CORE_DEAD
# message is sent prior to closing the socket.
# message is sent prior to closing the socket.
...
@@ -541,8 +546,22 @@ class EngineCoreProc(EngineCore):
...
@@ -541,8 +546,22 @@ class EngineCoreProc(EngineCore):
break
break
assert
not
isinstance
(
outputs
,
bytes
)
assert
not
isinstance
(
outputs
,
bytes
)
outputs
.
engine_index
=
engine_index
outputs
.
engine_index
=
engine_index
# Reclaim buffers that zmq is finished with.
while
pending
and
pending
[
-
1
][
0
].
done
:
reuse_buffers
.
append
(
pending
.
pop
()[
2
])
buffer
=
reuse_buffers
.
pop
()
if
reuse_buffers
else
bytearray
()
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
buffers
=
encoder
.
encode_into
(
outputs
,
buffer
)
socket
.
send_multipart
(
buffers
,
copy
=
False
)
tracker
=
socket
.
send_multipart
(
buffers
,
copy
=
False
,
track
=
True
)
if
not
tracker
.
done
:
ref
=
outputs
if
len
(
buffers
)
>
1
else
None
pending
.
appendleft
((
tracker
,
ref
,
buffer
))
elif
len
(
reuse_buffers
)
<
2
:
# Keep at most 2 buffers to reuse.
reuse_buffers
.
append
(
buffer
)
class
DPEngineCoreProc
(
EngineCoreProc
):
class
DPEngineCoreProc
(
EngineCoreProc
):
...
...
vllm/v1/engine/core_client.py
View file @
b07bf83c
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
import
contextlib
import
queue
import
queue
import
uuid
import
uuid
import
weakref
import
weakref
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
collections.abc
import
Awaitable
,
Sequence
from
collections.abc
import
Awaitable
,
Sequence
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
...
@@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
...
@@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
self
.
_wait_for_engine_startup
()
self
.
_wait_for_engine_startup
()
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
# Request objects which may contain pytorch-allocated tensors
# that we need to keep references to until zmq is done with the
# underlying data.
self
.
pending_messages
=
deque
[
tuple
[
zmq
.
MessageTracker
,
Any
]]()
success
=
True
success
=
True
finally
:
finally
:
if
not
success
:
if
not
success
:
...
@@ -459,6 +467,14 @@ class MPClient(EngineCoreClient):
...
@@ -459,6 +467,14 @@ class MPClient(EngineCoreClient):
if
self
.
resources
.
engine_dead
:
if
self
.
resources
.
engine_dead
:
raise
EngineDeadError
()
raise
EngineDeadError
()
def
add_pending_message
(
self
,
tracker
:
zmq
.
MessageTracker
,
msg
:
Any
):
if
not
tracker
.
done
:
self
.
pending_messages
.
appendleft
((
tracker
,
msg
))
def
free_pending_messages
(
self
):
while
self
.
pending_messages
and
self
.
pending_messages
[
-
1
][
0
].
done
:
self
.
pending_messages
.
pop
()
def
_process_utility_output
(
output
:
UtilityOutput
,
def
_process_utility_output
(
output
:
UtilityOutput
,
utility_results
:
dict
[
int
,
AnyFuture
]):
utility_results
:
dict
[
int
,
AnyFuture
]):
...
@@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
...
@@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Any
):
self
.
ensure_alive
()
self
.
ensure_alive
()
self
.
free_pending_messages
()
# (Identity, RequestType, SerializedRequest)
# (Identity, RequestType, SerializedRequest)
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
msg
=
(
self
.
core_engine
.
identity
,
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
*
self
.
encoder
.
encode
(
request
))
if
len
(
msg
)
<=
3
:
# No auxiliary buffers => no tensor backing buffers in request.
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
return
tracker
=
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
,
track
=
True
)
self
.
add_pending_message
(
tracker
,
request
)
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
def
call_utility
(
self
,
method
:
str
,
*
args
)
->
Any
:
call_id
=
uuid
.
uuid1
().
int
>>
64
call_id
=
uuid
.
uuid1
().
int
>>
64
...
@@ -698,19 +722,38 @@ class AsyncMPClient(MPClient):
...
@@ -698,19 +722,38 @@ class AsyncMPClient(MPClient):
def
_send_input
(
self
,
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request_type
:
EngineCoreRequestType
,
request
:
Any
,
request
:
Any
,
engine
:
Optional
[
CoreEngine
]
=
None
)
->
Awaitable
[
None
]:
engine
:
Optional
[
CoreEngine
]
=
None
)
->
Awaitable
[
Any
]:
self
.
ensure_alive
()
self
.
ensure_alive
()
if
engine
is
None
:
if
engine
is
None
:
engine
=
self
.
core_engine
engine
=
self
.
core_engine
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
message
=
(
request_type
.
value
,
*
self
.
encoder
.
encode
(
request
))
return
self
.
_send_input_message
(
message
,
engine
)
return
self
.
_send_input_message
(
message
,
engine
,
request
)
def
_send_input_message
(
self
,
message
:
tuple
[
bytestr
,
...],
def
_send_input_message
(
self
,
message
:
tuple
[
bytestr
,
engine
:
CoreEngine
)
->
Awaitable
[
None
]:
...],
engine
:
CoreEngine
,
objects
:
Any
)
->
Awaitable
[
Any
]:
"""
objects is a reference to retain until zmq is finished with the
buffers, in case they were extracted from tensors in the request.
"""
self
.
ensure_alive
()
self
.
ensure_alive
()
message
=
(
engine
.
identity
,
)
+
message
self
.
free_pending_messages
()
return
self
.
input_socket
.
send_multipart
(
message
,
copy
=
False
)
msg
=
(
engine
.
identity
,
)
+
message
if
not
objects
or
len
(
msg
)
<=
3
:
# No auxiliary buffers => no tensor backing buffers in request.
return
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
future
:
asyncio
.
Future
[
zmq
.
MessageTracker
]
future
=
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
,
track
=
True
)
def
add_pending
(
f
:
asyncio
.
Future
[
zmq
.
MessageTracker
]):
with
contextlib
.
suppress
(
BaseException
):
self
.
add_pending_message
(
f
.
result
(),
objects
)
future
.
add_done_callback
(
add_pending
)
return
future
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
async
def
call_utility_async
(
self
,
method
:
str
,
*
args
)
->
Any
:
return
await
self
.
_call_utility_async
(
method
,
return
await
self
.
_call_utility_async
(
method
,
...
@@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
...
@@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
self
.
utility_results
[
call_id
]
=
future
self
.
utility_results
[
call_id
]
=
future
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
message
=
(
EngineCoreRequestType
.
UTILITY
.
value
,
*
self
.
encoder
.
encode
(
(
call_id
,
method
,
args
)))
(
call_id
,
method
,
args
)))
await
self
.
_send_input_message
(
message
,
engine
)
await
self
.
_send_input_message
(
message
,
engine
,
args
)
self
.
_ensure_output_queue_task
()
self
.
_ensure_output_queue_task
()
return
await
future
return
await
future
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment