Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df04dffa
Unverified
Commit
df04dffa
authored
Dec 27, 2024
by
Robert Shaw
Committed by
GitHub
Dec 28, 2024
Browse files
[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)
parent
a6073124
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
247 additions
and
215 deletions
+247
-215
docs/requirements-docs.txt
docs/requirements-docs.txt
+1
-0
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+4
-9
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+4
-6
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+10
-1
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+1
-21
vllm/utils.py
vllm/utils.py
+87
-3
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+3
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+22
-89
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+47
-45
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+3
-3
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+5
-6
vllm/v1/utils.py
vllm/v1/utils.py
+60
-29
No files found.
docs/requirements-docs.txt
View file @
df04dffa
...
...
@@ -19,3 +19,4 @@ openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entr
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requests
zmq
tests/v1/engine/test_engine_core.py
View file @
df04dffa
...
...
@@ -7,7 +7,6 @@ from transformers import AutoTokenizer
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.core
import
EngineCore
...
...
@@ -43,13 +42,11 @@ def test_engine_core(monkeypatch):
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
"""Setup the EngineCore."""
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
vllm_config
=
engine_args
.
create_engine_config
()
executor_class
=
AsyncLLM
.
_get_executor_cls
(
vllm_config
)
engine_core
=
EngineCore
(
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
executor_class
)
"""Test basic request lifecycle."""
# First request.
...
...
@@ -151,13 +148,11 @@ def test_engine_core_advanced_sampling(monkeypatch):
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
"""Setup the EngineCore."""
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
vllm_config
=
engine_args
.
create_engine_config
()
executor_class
=
AsyncLLM
.
_get_executor_cls
(
vllm_config
)
engine_core
=
EngineCore
(
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
executor_class
)
"""Test basic request lifecycle."""
# First request.
request
:
EngineCoreRequest
=
make_request
()
...
...
tests/v1/engine/test_engine_core_client.py
View file @
df04dffa
...
...
@@ -86,11 +86,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
AsyncLLM
.
_get_executor_cls
(
vllm_config
)
client
=
EngineCoreClient
.
make_client
(
vllm_config
,
executor_class
,
UsageContext
.
UNKNOWN_CONTEXT
,
multiprocess_mode
=
multiprocessing_mode
,
asyncio_mode
=
False
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
)
MAX_TOKENS
=
20
...
...
@@ -158,11 +157,10 @@ async def test_engine_core_client_asyncio(monkeypatch):
usage_context
=
UsageContext
.
UNKNOWN_CONTEXT
)
executor_class
=
AsyncLLM
.
_get_executor_cls
(
vllm_config
)
client
=
EngineCoreClient
.
make_client
(
vllm_config
,
executor_class
,
UsageContext
.
UNKNOWN_CONTEXT
,
multiprocess_mode
=
True
,
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
)
MAX_TOKENS
=
20
...
...
vllm/entrypoints/openai/api_server.py
View file @
df04dffa
...
...
@@ -68,7 +68,7 @@ from vllm.entrypoints.utils import with_cancellation
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
set_ulimit
)
is_valid_ipv6_address
,
kill_process_tree
,
set_ulimit
)
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -737,6 +737,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
# The child processes will send SIGQUIT to this process when
# any error happens. This process then cleans up the whole tree.
# TODO(rob): move this into AsyncLLM.__init__ once we remove
# the context manager below.
def
sigquit_handler
(
signum
,
frame
):
kill_process_tree
(
os
.
getpid
())
signal
.
signal
(
signal
.
SIGQUIT
,
sigquit_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
app
=
build_app
(
args
)
...
...
vllm/executor/multiproc_worker_utils.py
View file @
df04dffa
import
asyncio
import
multiprocessing
import
os
import
sys
import
threading
...
...
@@ -13,10 +12,9 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.utils
import
cuda_is_initialized
from
vllm.utils
import
_check_multiproc_method
,
get_mp_context
if
HAS_TRITON
:
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
...
...
@@ -274,24 +272,6 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
file
.
write
=
write_with_prefix
# type: ignore[method-assign]
def
_check_multiproc_method
():
if
(
cuda_is_initialized
()
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
):
logger
.
warning
(
"CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"debugging.html#python-multiprocessing "
"for more information."
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
def
get_mp_context
():
_check_multiproc_method
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
def
set_multiprocessing_worker_envs
(
parallel_config
):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
...
...
vllm/utils.py
View file @
df04dffa
...
...
@@ -10,6 +10,7 @@ import importlib.metadata
import
importlib.util
import
inspect
import
ipaddress
import
multiprocessing
import
os
import
re
import
resource
...
...
@@ -20,6 +21,7 @@ import sys
import
tempfile
import
threading
import
time
import
traceback
import
uuid
import
warnings
import
weakref
...
...
@@ -29,8 +31,9 @@ from collections.abc import Hashable, Iterable, Mapping
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
List
,
Literal
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
Dict
,
Generator
,
Generic
,
Iterator
,
List
,
Literal
,
NamedTuple
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
from
uuid
import
uuid4
import
numpy
as
np
...
...
@@ -39,6 +42,8 @@ import psutil
import
torch
import
torch.types
import
yaml
import
zmq
import
zmq.asyncio
from
packaging.version
import
Version
from
torch.library
import
Library
from
typing_extensions
import
ParamSpec
,
TypeIs
,
assert_never
...
...
@@ -1844,7 +1849,7 @@ def memory_profiling(
result
.
non_kv_cache_memory_in_bytes
=
result
.
non_torch_increase_in_bytes
+
result
.
torch_peak_increase_in_bytes
+
result
.
weights_memory_in_bytes
# noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/
f46f394f4d4dbe4aae85403dec006199b34d2840
/python/sglang/srt/utils.py#L630 # noqa: E501
Curre
# Adapted from: https://github.com/sgl-project/sglang/blob/
v0.4.1
/python/sglang/srt/utils.py#L630 # noqa: E501
def
set_ulimit
(
target_soft_limit
=
65535
):
resource_type
=
resource
.
RLIMIT_NOFILE
current_soft
,
current_hard
=
resource
.
getrlimit
(
resource_type
)
...
...
@@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535):
"with error %s. This can cause fd limit errors like"
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n"
,
current_soft
,
e
)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    """Return the formatted traceback of the exception being handled.

    Intended to be called from inside an ``except`` block.

    Returns:
        The traceback lines of the current exception joined into a single
        string, as ``traceback.print_exc`` would print them.
    """
    # traceback.format_exc() is the standard-library shorthand for
    # "".join(traceback.format_exception(*sys.exc_info())).
    return traceback.format_exc()
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
    path: str,
    type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    PULL sockets connect to ``path`` and PUSH sockets bind to it. The
    kernel buffer size is chosen from the memory the host reports.

    Args:
        ctx: Context (sync or asyncio) used to create the socket.
        path: Endpoint handed to ``connect()`` / ``bind()``.
        type: ZMQ socket type; only ``PULL`` and ``PUSH`` are supported.

    Returns:
        The configured, connected (PULL) or bound (PUSH) socket.

    Raises:
        ValueError: If ``type`` is neither ``PULL`` nor ``PUSH``.
    """
    mem = psutil.virtual_memory()
    socket = ctx.socket(type)

    # Calculate buffer size based on system memory
    bytes_per_gib = 1024**3
    total_mem = mem.total / bytes_per_gib
    available_mem = mem.available / bytes_per_gib
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * bytes_per_gib)  # 0.5GB in bytes
    else:
        buf_size = -1  # Use system default buffer size

    try:
        # NOTE(review): a high-water mark of 0 presumably lifts the queue
        # limit -- confirm against the ZeroMQ socket-option reference.
        if type == zmq.constants.PULL:
            socket.setsockopt(zmq.constants.RCVHWM, 0)
            socket.setsockopt(zmq.constants.RCVBUF, buf_size)
            socket.connect(path)
        elif type == zmq.constants.PUSH:
            socket.setsockopt(zmq.constants.SNDHWM, 0)
            socket.setsockopt(zmq.constants.SNDBUF, buf_size)
            socket.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")
    except Exception:
        # Configuration failed: close the socket so it is not left open
        # in the caller's context, then surface the original error.
        socket.close(linger=0)
        raise

    return socket
@contextlib.contextmanager
def zmq_socket_ctx(
        path: str,
        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Yield a ZMQ socket backed by a private, short-lived context.

    A dedicated ``zmq.Context`` with two IO threads is created, the socket
    is built by ``make_zmq_socket`` and handed to the ``with`` body. The
    context is always destroyed with ``linger=0`` on exit. A
    ``KeyboardInterrupt`` raised while the socket is in use is logged at
    debug level and not re-raised.
    """
    zmq_context = zmq.Context(io_threads=2)  # type: ignore[attr-defined]
    try:
        sock = make_zmq_socket(zmq_context, path, type)
        yield sock
    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")
    finally:
        zmq_context.destroy(linger=0)
def _check_multiproc_method():
    """Force the `spawn` start method once CUDA has been initialized.

    When ``cuda_is_initialized()`` is true and the environment variable
    VLLM_WORKER_MULTIPROC_METHOD is not already "spawn", a warning is
    logged and the variable is overwritten with "spawn".
    """
    env_key = "VLLM_WORKER_MULTIPROC_METHOD"

    if not cuda_is_initialized():
        return
    if os.environ.get(env_key) == "spawn":
        return

    logger.warning("CUDA was previously initialized. We must use "
                   "the `spawn` multiprocessing start method. Setting "
                   "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                   "See https://docs.vllm.ai/en/latest/getting_started/"
                   "debugging.html#python-multiprocessing "
                   "for more information.")
    os.environ[env_key] = "spawn"
def get_mp_context():
    """Return the multiprocessing context for vLLM worker processes.

    ``_check_multiproc_method()`` runs first, so any switch to `spawn` it
    makes happens before ``envs.VLLM_WORKER_MULTIPROC_METHOD`` is read.
    """
    _check_multiproc_method()
    # NOTE(review): presumably `envs` resolves the variable at access time,
    # so the override above is observed here -- confirm in vllm.envs.
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
vllm/v1/engine/async_llm.py
View file @
df04dffa
...
...
@@ -75,11 +75,11 @@ class AsyncLLM(EngineClient):
# EngineCore (starts the engine in background process).
self
.
engine_core
=
EngineCoreClient
.
make_client
(
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
usage_context
=
usage_context
,
multiprocess_mode
=
True
,
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
self
.
log_stats
,
)
self
.
output_handler
:
Optional
[
asyncio
.
Task
]
=
None
...
...
vllm/v1/engine/core.py
View file @
df04dffa
...
...
@@ -3,20 +3,19 @@ import queue
import
signal
import
threading
import
time
from
dataclasses
import
dataclass
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.connection
import
Connection
from
typing
import
List
,
Tuple
,
Type
import
psutil
import
zmq
import
zmq.asyncio
from
msgspec
import
msgpack
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.executor.multiproc_worker_utils
import
get_mp_context
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.u
sage.usage_lib
import
UsageContext
from
vllm.u
tils
import
get_exception_traceback
,
zmq_socket_ctx
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
...
...
@@ -25,14 +24,13 @@ from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
PickleEncoder
from
vllm.v1.utils
import
make_zmq_socket
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
LOGGING_TIME_S
=
POLLING_TIMEOUT_S
LOGGING_TIME_S
=
5
class
EngineCore
:
...
...
@@ -42,9 +40,10 @@ class EngineCore:
self
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
usage_context
:
UsageContext
,
log_stats
:
bool
=
False
,
):
assert
vllm_config
.
model_config
.
runner_type
!=
"pooling"
self
.
log_stats
=
log_stats
logger
.
info
(
"Initializing an LLM engine (v%s) with config: %s"
,
VLLM_VERSION
,
vllm_config
)
...
...
@@ -134,29 +133,19 @@ class EngineCore:
self
.
model_executor
.
profile
(
is_start
)
@
dataclass
class
EngineCoreProcHandle
:
proc
:
BaseProcess
ready_path
:
str
input_path
:
str
output_path
:
str
class
EngineCoreProc
(
EngineCore
):
"""ZMQ-wrapper for running EngineCore in background process."""
READY_STR
=
"READY"
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
usage_context
:
UsageContext
,
input_path
:
str
,
output_path
:
str
,
ready_path
:
str
,
ready_pipe
:
Connection
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
log_stats
:
bool
=
False
,
):
super
().
__init__
(
vllm_config
,
executor_class
,
usage_context
)
super
().
__init__
(
vllm_config
,
executor_class
,
log_stats
)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
...
...
@@ -173,68 +162,7 @@ class EngineCoreProc(EngineCore):
daemon
=
True
).
start
()
# Send Readiness signal to EngineClient.
with
make_zmq_socket
(
ready_path
,
zmq
.
constants
.
PUSH
)
as
ready_socket
:
ready_socket
.
send_string
(
EngineCoreProc
.
READY_STR
)
@
staticmethod
def
wait_for_startup
(
proc
:
BaseProcess
,
ready_path
:
str
,
)
->
None
:
"""Wait until the EngineCore is ready."""
try
:
sync_ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
socket
=
sync_ctx
.
socket
(
zmq
.
constants
.
PULL
)
socket
.
connect
(
ready_path
)
# Wait for EngineCore to send EngineCoreProc.READY_STR.
while
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
logger
.
debug
(
"Waiting for EngineCoreProc to startup."
)
if
not
proc
.
is_alive
():
raise
RuntimeError
(
"EngineCoreProc failed to start."
)
message
=
socket
.
recv_string
()
assert
message
==
EngineCoreProc
.
READY_STR
except
BaseException
as
e
:
logger
.
exception
(
e
)
raise
e
finally
:
sync_ctx
.
destroy
(
linger
=
0
)
@
staticmethod
def
make_engine_core_process
(
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
usage_context
:
UsageContext
,
input_path
:
str
,
output_path
:
str
,
ready_path
:
str
,
)
->
EngineCoreProcHandle
:
context
=
get_mp_context
()
process_kwargs
=
{
"input_path"
:
input_path
,
"output_path"
:
output_path
,
"ready_path"
:
ready_path
,
"vllm_config"
:
vllm_config
,
"executor_class"
:
executor_class
,
"usage_context"
:
usage_context
,
}
# Run EngineCore busy loop in background process.
proc
=
context
.
Process
(
target
=
EngineCoreProc
.
run_engine_core
,
kwargs
=
process_kwargs
)
proc
.
start
()
# Wait for startup
EngineCoreProc
.
wait_for_startup
(
proc
,
ready_path
)
return
EngineCoreProcHandle
(
proc
=
proc
,
ready_path
=
ready_path
,
input_path
=
input_path
,
output_path
=
output_path
)
ready_pipe
.
send
({
"status"
:
"READY"
})
@
staticmethod
def
run_engine_core
(
*
args
,
**
kwargs
):
...
...
@@ -258,6 +186,7 @@ class EngineCoreProc(EngineCore):
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
signal
.
signal
(
signal
.
SIGINT
,
signal_handler
)
parent_process
=
psutil
.
Process
().
parent
()
engine_core
=
None
try
:
engine_core
=
EngineCoreProc
(
*
args
,
**
kwargs
)
...
...
@@ -266,9 +195,10 @@ class EngineCoreProc(EngineCore):
except
SystemExit
:
logger
.
debug
(
"EngineCore interrupted."
)
except
BaseException
as
e
:
logger
.
exception
(
e
)
raise
e
except
Exception
:
traceback
=
get_exception_traceback
()
logger
.
error
(
"EngineCore hit an exception: %s"
,
traceback
)
parent_process
.
send_signal
(
signal
.
SIGQUIT
)
finally
:
if
engine_core
is
not
None
:
...
...
@@ -309,6 +239,9 @@ class EngineCoreProc(EngineCore):
def
_log_stats
(
self
):
"""Log basic stats every LOGGING_TIME_S"""
if
not
self
.
log_stats
:
return
now
=
time
.
time
()
if
now
-
self
.
_last_logging_time
>
LOGGING_TIME_S
:
...
...
@@ -339,7 +272,7 @@ class EngineCoreProc(EngineCore):
decoder_add_req
=
PickleEncoder
()
decoder_abort_req
=
PickleEncoder
()
with
make_
zmq_socket
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
with
zmq_socket
_ctx
(
input_path
,
zmq
.
constants
.
PULL
)
as
socket
:
while
True
:
# (RequestType, RequestData)
type_frame
,
data_frame
=
socket
.
recv_multipart
(
copy
=
False
)
...
...
@@ -367,7 +300,7 @@ class EngineCoreProc(EngineCore):
# Reuse send buffer.
buffer
=
bytearray
()
with
make_
zmq_socket
(
output_path
,
zmq
.
constants
.
PUSH
)
as
socket
:
with
zmq_socket
_ctx
(
output_path
,
zmq
.
constants
.
PUSH
)
as
socket
:
while
True
:
engine_core_outputs
=
self
.
output_queue
.
get
()
outputs
=
EngineCoreOutputs
(
outputs
=
engine_core_outputs
)
...
...
vllm/v1/engine/core_client.py
View file @
df04dffa
import
os
import
weakref
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Type
import
msgspec
import
zmq
import
zmq.asyncio
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_open_zmq_ipc_path
,
kill_process_tree
from
vllm.utils
import
get_open_zmq_ipc_path
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestUnion
)
from
vllm.v1.engine.core
import
(
EngineCore
,
EngineCoreProc
,
EngineCoreProcHandle
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.serial_utils
import
PickleEncoder
from
vllm.v1.utils
import
BackgroundProcHandle
logger
=
init_logger
(
__name__
)
...
...
@@ -31,10 +31,11 @@ class EngineCoreClient:
@
staticmethod
def
make_client
(
*
args
,
multiprocess_mode
:
bool
,
asyncio_mode
:
bool
,
**
kwargs
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
log_stats
:
bool
=
False
,
)
->
"EngineCoreClient"
:
# TODO: support this for debugging purposes.
...
...
@@ -44,12 +45,12 @@ class EngineCoreClient:
"is not currently supported."
)
if
multiprocess_mode
and
asyncio_mode
:
return
AsyncMPClient
(
*
args
,
**
kwarg
s
)
return
AsyncMPClient
(
vllm_config
,
executor_class
,
log_stat
s
)
if
multiprocess_mode
and
not
asyncio_mode
:
return
SyncMPClient
(
*
args
,
**
kwarg
s
)
return
SyncMPClient
(
vllm_config
,
executor_class
,
log_stat
s
)
return
InprocClient
(
*
args
,
**
kwarg
s
)
return
InprocClient
(
vllm_config
,
executor_class
,
log_stat
s
)
def
shutdown
(
self
):
pass
...
...
@@ -128,9 +129,10 @@ class MPClient(EngineCoreClient):
def
__init__
(
self
,
*
args
,
asyncio_mode
:
bool
,
**
kwargs
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
log_stats
:
bool
=
False
,
):
# Serialization setup.
self
.
encoder
=
PickleEncoder
()
...
...
@@ -143,7 +145,6 @@ class MPClient(EngineCoreClient):
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Path for IPC.
ready_path
=
get_open_zmq_ipc_path
()
output_path
=
get_open_zmq_ipc_path
()
input_path
=
get_open_zmq_ipc_path
()
...
...
@@ -156,47 +157,40 @@ class MPClient(EngineCoreClient):
self
.
input_socket
.
bind
(
input_path
)
# Start EngineCore in background process.
self
.
proc_handle
:
Optional
[
EngineCoreProcHandle
]
self
.
proc_handle
=
EngineCoreProc
.
make_engine_core_process
(
*
args
,
input_path
=
input_path
,
# type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path
=
output_path
,
# type: ignore[misc]
ready_path
=
ready_path
,
# type: ignore[misc]
**
kwargs
,
)
self
.
_finalizer
=
weakref
.
finalize
(
self
,
self
.
shutdown
)
self
.
proc_handle
:
Optional
[
BackgroundProcHandle
]
self
.
proc_handle
=
BackgroundProcHandle
(
input_path
=
input_path
,
output_path
=
output_path
,
process_name
=
"EngineCore"
,
target_fn
=
EngineCoreProc
.
run_engine_core
,
process_kwargs
=
{
"vllm_config"
:
vllm_config
,
"executor_class"
:
executor_class
,
"log_stats"
:
log_stats
,
})
def
shutdown
(
self
):
# Shut down the zmq context.
self
.
ctx
.
destroy
(
linger
=
0
)
if
hasattr
(
self
,
"proc_handle"
)
and
self
.
proc_handle
:
# Shutdown the process if needed.
if
self
.
proc_handle
.
proc
.
is_alive
():
self
.
proc_handle
.
proc
.
terminate
()
self
.
proc_handle
.
proc
.
join
(
5
)
if
self
.
proc_handle
.
proc
.
is_alive
():
kill_process_tree
(
self
.
proc_handle
.
proc
.
pid
)
# Remove zmq ipc socket files
ipc_sockets
=
[
self
.
proc_handle
.
ready_path
,
self
.
proc_handle
.
output_path
,
self
.
proc_handle
.
input_path
]
for
ipc_socket
in
ipc_sockets
:
socket_file
=
ipc_socket
.
replace
(
"ipc://"
,
""
)
if
os
and
os
.
path
.
exists
(
socket_file
):
os
.
remove
(
socket_file
)
self
.
proc_handle
.
shutdown
()
self
.
proc_handle
=
None
class
SyncMPClient
(
MPClient
):
"""Synchronous client for multi-proc EngineCore."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
asyncio_mode
=
False
,
**
kwargs
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
log_stats
:
bool
=
False
):
super
().
__init__
(
asyncio_mode
=
False
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
)
def
get_output
(
self
)
->
List
[
EngineCoreOutput
]:
...
...
@@ -225,8 +219,16 @@ class SyncMPClient(MPClient):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
asyncio_mode
=
True
,
**
kwargs
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
executor_class
:
Type
[
Executor
],
log_stats
:
bool
=
False
):
super
().
__init__
(
asyncio_mode
=
True
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
log_stats
,
)
async
def
get_output_async
(
self
)
->
List
[
EngineCoreOutput
]:
...
...
vllm/v1/engine/llm_engine.py
View file @
df04dffa
...
...
@@ -72,11 +72,11 @@ class LLMEngine:
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self
.
engine_core
=
EngineCoreClient
.
make_client
(
vllm_config
,
executor_class
,
usage_context
,
multiprocess_mode
=
multiprocess_mode
,
asyncio_mode
=
False
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
False
,
)
@
classmethod
...
...
vllm/v1/executor/multiproc_executor.py
View file @
df04dffa
...
...
@@ -17,13 +17,12 @@ from vllm.distributed import (destroy_distributed_environment,
from
vllm.distributed.device_communicators.shm_broadcast
import
(
Handle
,
MessageQueue
)
from
vllm.executor.multiproc_worker_utils
import
(
_add_prefix
,
get_mp_context
,
set_multiprocessing_worker_envs
)
_add_prefix
,
set_multiprocessing_worker_envs
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_distributed_init_method
,
get_
open_por
t
,
get_open_
zmq_ipc_path
)
from
vllm.utils
import
(
get_distributed_init_method
,
get_
mp_contex
t
,
get_open_
port
,
get_open_zmq_ipc_path
,
zmq_socket_ctx
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
make_zmq_socket
from
vllm.worker.worker_base
import
WorkerWrapperBase
logger
=
init_logger
(
__name__
)
...
...
@@ -250,7 +249,7 @@ class WorkerProc:
worker_response_mq_handle
=
self
.
worker_response_mq
.
export_handle
()
# Send Readiness signal to EngineCore process.
with
make_
zmq_socket
(
ready_path
,
zmq
.
constants
.
PUSH
)
as
ready_socket
:
with
zmq_socket
_ctx
(
ready_path
,
zmq
.
constants
.
PUSH
)
as
ready_socket
:
payload
=
pickle
.
dumps
(
worker_response_mq_handle
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
...
...
@@ -352,7 +351,7 @@ class WorkerProc:
ready_path
:
str
,
)
->
Optional
[
Handle
]:
"""Wait until the Worker is ready."""
with
make_
zmq_socket
(
ready_path
,
zmq
.
constants
.
PULL
)
as
socket
:
with
zmq_socket
_ctx
(
ready_path
,
zmq
.
constants
.
PULL
)
as
socket
:
# Wait for Worker to send READY.
while
socket
.
poll
(
timeout
=
POLLING_TIMEOUT_MS
)
==
0
:
...
...
vllm/v1/utils.py
View file @
df04dffa
import
os
import
weakref
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Generic
,
Iterator
,
List
,
Optional
,
TypeVar
,
Union
,
overload
)
import
zmq
from
typing
import
(
Any
,
Callable
,
Dict
,
Generic
,
List
,
Optional
,
TypeVar
,
Union
,
overload
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_mp_context
,
kill_process_tree
logger
=
init_logger
(
__name__
)
...
...
@@ -77,27 +77,58 @@ class ConstantList(Generic[T], Sequence):
return
len
(
self
.
_x
)
@
contextmanager
def
make_zmq_socket
(
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
try
:
socket
=
ctx
.
socket
(
type
)
if
type
==
zmq
.
constants
.
PULL
:
socket
.
connect
(
path
)
elif
type
==
zmq
.
constants
.
PUSH
:
socket
.
bind
(
path
)
else
:
raise
ValueError
(
f
"Unknown Socket Type:
{
type
}
"
)
yield
socket
except
KeyboardInterrupt
:
logger
.
debug
(
"Worker had Keyboard Interrupt."
)
finally
:
ctx
.
destroy
(
linger
=
0
)
class BackgroundProcHandle:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.

    The constructor starts ``target_fn`` in a new process and blocks until
    that process reports ``{"status": "READY"}`` on a one-way pipe.
    ``shutdown()`` terminates the process and removes the ZMQ IPC socket
    files; it also runs when the handle is garbage collected or at
    interpreter exit.
    """

    def __init__(
        self,
        input_path: str,
        output_path: str,
        process_name: str,
        target_fn: Callable,
        process_kwargs: Dict[Any, Any],
    ):
        """Start the background process and wait for its readiness signal.

        Args:
            input_path: IPC endpoint forwarded to the child as
                ``input_path``; its socket file is removed on shutdown.
            output_path: IPC endpoint forwarded to the child as
                ``output_path``; its socket file is removed on shutdown.
            process_name: Name used in the startup-failure error message.
            target_fn: Callable executed in the child process.
            process_kwargs: Extra keyword arguments for ``target_fn``. Must
                not contain ``ready_pipe``, ``input_path`` or
                ``output_path``. The dictionary is not modified.

        Raises:
            RuntimeError: If the child exits before reporting readiness or
                reports a status other than "READY".
        """
        assert ("ready_pipe" not in process_kwargs
                and "input_path" not in process_kwargs
                and "output_path" not in process_kwargs)

        self.input_path = input_path
        self.output_path = output_path

        context = get_mp_context()
        reader, writer = context.Pipe(duplex=False)

        # Build the child's kwargs in a fresh dict so the caller's
        # dictionary is left untouched.
        child_kwargs = dict(process_kwargs)
        child_kwargs["ready_pipe"] = writer
        child_kwargs["input_path"] = input_path
        child_kwargs["output_path"] = output_path

        # Run the target's busy loop in a background process.
        self.proc = context.Process(target=target_fn, kwargs=child_kwargs)

        # The finalizer callback must not reference `self` (a bound method
        # would keep this object alive until interpreter exit), so cleanup
        # lives in a static helper that receives only what it needs.
        self._finalizer = weakref.finalize(self,
                                           BackgroundProcHandle._cleanup,
                                           self.proc, input_path,
                                           output_path)
        self.proc.start()

        # The child now holds its own copy of the write end. Close ours so
        # recv() raises EOFError instead of blocking forever if the child
        # exits before reporting readiness.
        writer.close()

        # Wait for startup.
        try:
            status = reader.recv()["status"]
        except EOFError:
            status = None
        finally:
            reader.close()

        if status != "READY":
            raise RuntimeError(f"{process_name} initialization failed. "
                               "See root cause above.")

    def __del__(self):
        self.shutdown()

    def shutdown(self):
        """Terminate the process and remove the IPC socket files.

        Safe to call repeatedly and on a partially constructed instance;
        the cleanup itself runs at most once.
        """
        finalizer = getattr(self, "_finalizer", None)
        if finalizer is not None:
            finalizer()

    @staticmethod
    def _cleanup(proc, input_path: str, output_path: str) -> None:
        """Stop ``proc`` if it is still running and delete the socket files."""
        # Shutdown the process if needed.
        if proc.is_alive():
            proc.terminate()
            proc.join(5)

            # Still alive after the 5 second grace period: kill the tree.
            if proc.is_alive():
                kill_process_tree(proc.pid)

        # Remove zmq ipc socket files
        for ipc_socket in (output_path, input_path):
            socket_file = ipc_socket.replace("ipc://", "")
            # The truthiness check on `os` is kept from the original code;
            # presumably it guards against module globals already being
            # cleared during interpreter shutdown.
            if os and os.path.exists(socket_file):
                try:
                    os.remove(socket_file)
                except FileNotFoundError:
                    # Removed by someone else between the check and here.
                    pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment