Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5db6b2c9
Unverified
Commit
5db6b2c9
authored
Mar 04, 2025
by
Nick Hill
Committed by
GitHub
Mar 04, 2025
Browse files
[V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (#13869)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
6247bae6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
40 deletions
+68
-40
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+0
-2
vllm/utils.py
vllm/utils.py
+12
-10
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+56
-28
No files found.
tests/v1/engine/test_llm_engine.py
View file @
5db6b2c9
...
@@ -15,8 +15,6 @@ DTYPE = "half"
...
@@ -15,8 +15,6 @@ DTYPE = "half"
def
_vllm_model
(
apc
:
bool
,
vllm_runner
,
monkeypatch
):
def
_vllm_model
(
apc
:
bool
,
vllm_runner
,
monkeypatch
):
"""Set up VllmRunner instance."""
"""Set up VllmRunner instance."""
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
# TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
return
vllm_runner
(
return
vllm_runner
(
MODEL
,
MODEL
,
dtype
=
DTYPE
,
dtype
=
DTYPE
,
...
...
vllm/utils.py
View file @
5db6b2c9
...
@@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str:
...
@@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str:
return
f
"ipc://
{
base_rpc_path
}
/
{
uuid4
()
}
"
return
f
"ipc://
{
base_rpc_path
}
/
{
uuid4
()
}
"
def
get_open_zmq_inproc_path
()
->
str
:
return
f
"inproc://
{
uuid4
()
}
"
def
get_open_port
()
->
int
:
def
get_open_port
()
->
int
:
"""
"""
Get an open port for the vLLM process to listen on.
Get an open port for the vLLM process to listen on.
...
@@ -2108,12 +2112,12 @@ def get_exception_traceback():
...
@@ -2108,12 +2112,12 @@ def get_exception_traceback():
def
make_zmq_socket
(
def
make_zmq_socket
(
ctx
:
Union
[
zmq
.
asyncio
.
Context
,
zmq
.
Context
],
# type: ignore[name-defined]
ctx
:
Union
[
zmq
.
asyncio
.
Context
,
zmq
.
Context
],
# type: ignore[name-defined]
path
:
str
,
path
:
str
,
type
:
Any
,
socket_
type
:
Any
,
)
->
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]:
# type: ignore[name-defined]
)
->
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]:
# type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
"""Make a ZMQ socket with the proper bind/connect semantics."""
mem
=
psutil
.
virtual_memory
()
mem
=
psutil
.
virtual_memory
()
socket
=
ctx
.
socket
(
type
)
socket
=
ctx
.
socket
(
socket_
type
)
# Calculate buffer size based on system memory
# Calculate buffer size based on system memory
total_mem
=
mem
.
total
/
1024
**
3
total_mem
=
mem
.
total
/
1024
**
3
...
@@ -2127,29 +2131,27 @@ def make_zmq_socket(
...
@@ -2127,29 +2131,27 @@ def make_zmq_socket(
else
:
else
:
buf_size
=
-
1
# Use system default buffer size
buf_size
=
-
1
# Use system default buffer size
if
type
==
zmq
.
constants
.
PULL
:
if
socket_
type
==
zmq
.
constants
.
PULL
:
socket
.
setsockopt
(
zmq
.
constants
.
RCVHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
RCVHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
RCVBUF
,
buf_size
)
socket
.
setsockopt
(
zmq
.
constants
.
RCVBUF
,
buf_size
)
socket
.
connect
(
path
)
socket
.
connect
(
path
)
elif
type
==
zmq
.
constants
.
PUSH
:
elif
socket_
type
==
zmq
.
constants
.
PUSH
:
socket
.
setsockopt
(
zmq
.
constants
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
SNDHWM
,
0
)
socket
.
setsockopt
(
zmq
.
constants
.
SNDBUF
,
buf_size
)
socket
.
setsockopt
(
zmq
.
constants
.
SNDBUF
,
buf_size
)
socket
.
bind
(
path
)
socket
.
bind
(
path
)
else
:
else
:
raise
ValueError
(
f
"Unknown Socket Type:
{
type
}
"
)
raise
ValueError
(
f
"Unknown Socket Type:
{
socket_
type
}
"
)
return
socket
return
socket
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
zmq_socket_ctx
(
def
zmq_socket_ctx
(
path
:
str
,
socket_type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
(
io_threads
=
2
)
# type: ignore[attr-defined]
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
try
:
try
:
yield
make_zmq_socket
(
ctx
,
path
,
type
)
yield
make_zmq_socket
(
ctx
,
path
,
socket_
type
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
logger
.
debug
(
"Got Keyboard Interrupt."
)
logger
.
debug
(
"Got Keyboard Interrupt."
)
...
...
vllm/v1/engine/core_client.py
View file @
5db6b2c9
...
@@ -18,8 +18,8 @@ import zmq.asyncio
...
@@ -18,8 +18,8 @@ import zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.utils
import
(
get_open_zmq_i
p
c_path
,
kill_process_tree
,
from
vllm.utils
import
(
get_open_zmq_i
npro
c_path
,
get_open_zmq_ipc_path
,
make_zmq_socket
)
kill_process_tree
,
make_zmq_socket
)
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
...
@@ -202,10 +202,11 @@ class BackgroundResources:
...
@@ -202,10 +202,11 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
,
zmq
.
asyncio
.
Context
]
=
None
ctx
:
Union
[
zmq
.
Context
]
=
None
output_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
output_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
input_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
input_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
proc_handle
:
Optional
[
BackgroundProcHandle
]
=
None
proc_handle
:
Optional
[
BackgroundProcHandle
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
def
__call__
(
self
):
def
__call__
(
self
):
"""Clean up background resources."""
"""Clean up background resources."""
...
@@ -218,8 +219,13 @@ class BackgroundResources:
...
@@ -218,8 +219,13 @@ class BackgroundResources:
self
.
output_socket
.
close
(
linger
=
0
)
self
.
output_socket
.
close
(
linger
=
0
)
if
self
.
input_socket
is
not
None
:
if
self
.
input_socket
is
not
None
:
self
.
input_socket
.
close
(
linger
=
0
)
self
.
input_socket
.
close
(
linger
=
0
)
if
self
.
ctx
is
not
None
:
if
self
.
shutdown_path
is
not
None
:
self
.
ctx
.
destroy
(
linger
=
0
)
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with
self
.
ctx
.
socket
(
zmq
.
PAIR
)
as
shutdown_sender
:
shutdown_sender
.
connect
(
self
.
shutdown_path
)
# Send shutdown signal.
shutdown_sender
.
send
(
b
''
)
class
MPClient
(
EngineCoreClient
):
class
MPClient
(
EngineCoreClient
):
...
@@ -261,28 +267,23 @@ class MPClient(EngineCoreClient):
...
@@ -261,28 +267,23 @@ class MPClient(EngineCoreClient):
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
self
.
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
# ZMQ setup.
# ZMQ setup.
self
.
ctx
=
(
sync_ctx
=
zmq
.
Context
()
zmq
.
asyncio
.
Context
()
# type: ignore[attr-defined]
self
.
ctx
=
zmq
.
asyncio
.
Context
(
sync_ctx
)
if
asyncio_mode
else
sync_ctx
if
asyncio_mode
else
zmq
.
Context
())
# type: ignore[attr-defined]
# This will ensure resources created so far are closed
# This will ensure resources created so far are closed
# when the client is garbage collected, even if an
# when the client is garbage collected, even if an
# exception is raised mid-construction.
# exception is raised mid-construction.
resources
=
BackgroundResources
(
ctx
=
s
elf
.
ctx
)
self
.
resources
=
BackgroundResources
(
ctx
=
s
ync_
ctx
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
resources
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
resources
)
# Paths
and sockets
for IPC.
# Paths for IPC.
output_path
=
get_open_zmq_ipc_path
()
self
.
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
resources
.
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
# Start EngineCore in background process.
# Start EngineCore in background process.
resources
.
proc_handle
=
BackgroundProcHandle
(
self
.
resources
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
input_path
=
input_path
,
output_path
=
output_path
,
output_path
=
self
.
output_path
,
process_name
=
"EngineCore"
,
process_name
=
"EngineCore"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
process_kwargs
=
{
...
@@ -291,8 +292,10 @@ class MPClient(EngineCoreClient):
...
@@ -291,8 +292,10 @@ class MPClient(EngineCoreClient):
"log_stats"
:
log_stats
,
"log_stats"
:
log_stats
,
})
})
self
.
output_socket
=
resources
.
output_socket
# Create input socket.
self
.
input_socket
=
resources
.
input_socket
self
.
resources
.
input_socket
=
make_zmq_socket
(
self
.
ctx
,
input_path
,
zmq
.
constants
.
PUSH
)
self
.
input_socket
=
self
.
resources
.
input_socket
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
self
.
utility_results
:
dict
[
int
,
AnyFuture
]
=
{}
def
shutdown
(
self
):
def
shutdown
(
self
):
...
@@ -325,27 +328,48 @@ class SyncMPClient(MPClient):
...
@@ -325,27 +328,48 @@ class SyncMPClient(MPClient):
# Ensure that the outputs socket processing thread does not have
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
# a ref to the client which prevents gc.
output_socket
=
self
.
output_socket
ctx
=
self
.
ctx
output_path
=
self
.
output_path
decoder
=
self
.
decoder
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
outputs_queue
=
self
.
outputs_queue
shutdown_path
=
get_open_zmq_inproc_path
()
self
.
resources
.
shutdown_path
=
shutdown_path
def
process_outputs_socket
():
def
process_outputs_socket
():
shutdown_socket
=
ctx
.
socket
(
zmq
.
PAIR
)
shutdown_socket
.
bind
(
shutdown_path
)
out_socket
=
make_zmq_socket
(
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
try
:
try
:
poller
=
zmq
.
Poller
()
poller
.
register
(
shutdown_socket
)
poller
.
register
(
out_socket
)
while
True
:
while
True
:
(
frame
,
)
=
output_socket
.
recv_multipart
(
copy
=
False
)
socks
=
poller
.
poll
()
if
not
socks
:
continue
if
len
(
socks
)
==
2
or
socks
[
0
][
0
]
==
shutdown_socket
:
# shutdown signal, exit thread.
break
(
frame
,
)
=
out_socket
.
recv_multipart
(
copy
=
False
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
outputs
=
decoder
.
decode
(
frame
.
buffer
)
if
outputs
.
utility_output
:
if
outputs
.
utility_output
:
_process_utility_output
(
outputs
.
utility_output
,
_process_utility_output
(
outputs
.
utility_output
,
utility_results
)
utility_results
)
else
:
else
:
outputs_queue
.
put_nowait
(
outputs
)
outputs_queue
.
put_nowait
(
outputs
)
except
zmq
.
error
.
ContextTerminated
:
finally
:
# Expected when the class is GC'd / during process termination.
# Close sockets.
pass
shutdown_socket
.
close
(
linger
=
0
)
out_socket
.
close
(
linger
=
0
)
# Process outputs from engine in separate thread.
# Process outputs from engine in separate thread.
Thread
(
target
=
process_outputs_socket
,
daemon
=
True
).
start
()
self
.
output_queue_thread
=
Thread
(
target
=
process_outputs_socket
,
name
=
"EngineCoreOutputQueueThread"
,
daemon
=
True
)
self
.
output_queue_thread
.
start
()
def
get_output
(
self
)
->
EngineCoreOutputs
:
def
get_output
(
self
)
->
EngineCoreOutputs
:
return
self
.
outputs_queue
.
get
()
return
self
.
outputs_queue
.
get
()
...
@@ -424,10 +448,13 @@ class AsyncMPClient(MPClient):
...
@@ -424,10 +448,13 @@ class AsyncMPClient(MPClient):
# Perform IO in separate task to parallelize as much as possible.
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
# Avoid task having direct reference back to the client.
self
.
outputs_queue
=
asyncio
.
Queue
()
self
.
outputs_queue
=
asyncio
.
Queue
()
output_socket
=
self
.
output_socket
decoder
=
self
.
decoder
decoder
=
self
.
decoder
utility_results
=
self
.
utility_results
utility_results
=
self
.
utility_results
outputs_queue
=
self
.
outputs_queue
outputs_queue
=
self
.
outputs_queue
output_path
=
self
.
output_path
output_socket
=
make_zmq_socket
(
self
.
ctx
,
output_path
,
zmq
.
constants
.
PULL
)
self
.
resources
.
output_socket
=
output_socket
async
def
process_outputs_socket
():
async
def
process_outputs_socket
():
while
True
:
while
True
:
...
@@ -439,7 +466,8 @@ class AsyncMPClient(MPClient):
...
@@ -439,7 +466,8 @@ class AsyncMPClient(MPClient):
else
:
else
:
outputs_queue
.
put_nowait
(
outputs
)
outputs_queue
.
put_nowait
(
outputs
)
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
())
self
.
queue_task
=
asyncio
.
create_task
(
process_outputs_socket
(),
name
=
"EngineCoreOutputQueueTask"
)
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
async
def
get_output_async
(
self
)
->
EngineCoreOutputs
:
if
self
.
outputs_queue
is
None
:
if
self
.
outputs_queue
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment