Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
acd5511b
Unverified
Commit
acd5511b
authored
Sep 16, 2024
by
Nick Hill
Committed by
GitHub
Sep 16, 2024
Browse files
[BugFix] Fix clean shutdown issues (#8492)
parent
837c1968
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
215 additions
and
136 deletions
+215
-136
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+8
-2
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+45
-25
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+13
-8
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+10
-11
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+109
-72
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+7
-1
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+0
-14
vllm/executor/multiproc_worker_utils.py
vllm/executor/multiproc_worker_utils.py
+4
-1
vllm/executor/ray_tpu_executor.py
vllm/executor/ray_tpu_executor.py
+2
-0
vllm/scripts.py
vllm/scripts.py
+2
-2
vllm/utils.py
vllm/utils.py
+15
-0
No files found.
tests/async_engine/test_async_llm_engine.py
View file @
acd5511b
...
@@ -26,6 +26,11 @@ class RequestOutput:
...
@@ -26,6 +26,11 @@ class RequestOutput:
finished
:
bool
=
False
finished
:
bool
=
False
@
dataclass
class
MockModelConfig
:
use_async_output_proc
=
True
class
MockEngine
:
class
MockEngine
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -35,6 +40,7 @@ class MockEngine:
...
@@ -35,6 +40,7 @@ class MockEngine:
self
.
request_id
=
None
self
.
request_id
=
None
# Ugly, remove dependency when possible
# Ugly, remove dependency when possible
self
.
parallel_config
=
ParallelConfig
(
1
,
1
,
False
)
self
.
parallel_config
=
ParallelConfig
(
1
,
1
,
False
)
self
.
model_config
=
MockModelConfig
()
async
def
step_async
(
self
,
virtual_engine
):
async
def
step_async
(
self
,
virtual_engine
):
# PP size is 1, ignore virtual engine
# PP size is 1, ignore virtual engine
...
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
...
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_new_requests_event
():
async
def
test_new_requests_event
():
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
False
)
engine
=
MockAsyncLLMEngine
()
engine
.
start_background_loop
()
engine
.
start_background_loop
()
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
assert
engine
.
engine
.
step_calls
==
0
assert
engine
.
engine
.
step_calls
==
0
...
@@ -113,7 +119,7 @@ async def test_new_requests_event():
...
@@ -113,7 +119,7 @@ async def test_new_requests_event():
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
add_request_calls
==
3
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
assert
engine
.
engine
.
step_calls
==
old_step_calls
+
1
engine
=
MockAsyncLLMEngine
(
worker_use_ray
=
True
)
engine
=
MockAsyncLLMEngine
()
assert
engine
.
get_model_config
()
is
not
None
assert
engine
.
get_model_config
()
is
not
None
assert
engine
.
get_tokenizer
()
is
not
None
assert
engine
.
get_tokenizer
()
is
not
None
assert
engine
.
get_decoding_config
()
is
not
None
assert
engine
.
get_decoding_config
()
is
not
None
...
...
vllm/engine/async_llm_engine.py
View file @
acd5511b
import
asyncio
import
asyncio
import
time
import
time
import
weakref
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
...
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
...
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller.
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
will be automatically started in the generate call.
...
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
...
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
_engine_class
:
Type
[
_AsyncLLMEngine
]
=
_AsyncLLMEngine
def
__init__
(
self
,
def
__init__
(
self
,
worker_use_ray
:
bool
,
*
args
,
*
args
,
log_requests
:
bool
=
True
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
self
.
worker_use_ray
=
worker_use_ray
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_engine_class
(
*
args
,
**
kwargs
)
self
.
engine
=
self
.
_engine_class
(
*
args
,
**
kwargs
)
# This ensures quick processing of request outputs
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# so the append to asyncio queues is not delayed,
# especially for multi-step.
# especially for multi-step.
#
self
.
use_process_request_outputs_callback
=
(
self
.
use_process_request_outputs_callback
=
True
self
.
engine
.
model_config
.
use_async_output_proc
)
if
self
.
use_process_request_outputs_callback
:
if
self
.
use_process_request_outputs_callback
:
self
.
engine
.
process_request_outputs_callback
=
\
self
.
engine
.
process_request_outputs_callback
=
\
self
.
process_request_outputs
weak_bind
(
self
.
process_request_outputs
)
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
# We need to keep a reference to unshielded
...
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
...
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
# Lazy initialized fields
# Lazy initialized fields
self
.
_request_tracker
:
RequestTracker
self
.
_request_tracker
:
RequestTracker
def
__del__
(
self
):
if
rt
:
=
getattr
(
self
,
"request_tracker"
,
None
):
# Wake up engine loop so that it will exit cleanly
rt
.
new_requests_event
.
set
()
@
classmethod
@
classmethod
def
_get_executor_cls
(
def
_get_executor_cls
(
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
cls
,
engine_config
:
EngineConfig
)
->
Type
[
ExecutorAsyncBase
]:
...
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
...
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
raise
TypeError
(
raise
TypeError
(
"distributed_executor_backend must be a subclass of "
"distributed_executor_backend must be a subclass of "
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
f
"ExecutorAsyncBase. Got
{
distributed_executor_backend
}
."
)
if
distributed_executor_backend
.
uses_ray
:
# type: ignore
initialize_ray_cluster
(
engine_config
.
parallel_config
)
executor_class
=
distributed_executor_backend
executor_class
=
distributed_executor_backend
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
elif
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
if
distributed_executor_backend
==
"ray"
:
if
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
from
vllm.executor.ray_tpu_executor
import
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
executor_class
=
RayTPUExecutorAsync
else
:
else
:
...
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
...
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
from
vllm.executor.xpu_executor
import
XPUExecutorAsync
from
vllm.executor.xpu_executor
import
XPUExecutorAsync
executor_class
=
XPUExecutorAsync
executor_class
=
XPUExecutorAsync
elif
distributed_executor_backend
==
"ray"
:
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.multiproc_xpu_executor
import
(
from
vllm.executor.multiproc_xpu_executor
import
(
MultiprocessingXPUExecutorAsync
)
MultiprocessingXPUExecutorAsync
)
executor_class
=
MultiprocessingXPUExecutorAsync
executor_class
=
MultiprocessingXPUExecutorAsync
...
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
...
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
raise
RuntimeError
(
raise
RuntimeError
(
"Not supported distributed execution model on XPU device."
)
"Not supported distributed execution model on XPU device."
)
elif
distributed_executor_backend
==
"ray"
:
elif
distributed_executor_backend
==
"ray"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
executor_class
=
RayGPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
elif
distributed_executor_backend
==
"mp"
:
...
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
...
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
def
from_engine_args
(
def
from_engine_args
(
cls
,
cls
,
engine_args
:
AsyncEngineArgs
,
engine_args
:
AsyncEngineArgs
,
engine_config
:
Optional
[
EngineConfig
]
=
None
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLMEngine"
:
)
->
"AsyncLLMEngine"
:
"""Creates an async LLM engine from the engine arguments."""
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
# Create the engine configs.
engine_config
=
engine_args
.
create_engine_config
()
if
engine_config
is
None
:
engine_config
=
engine_args
.
create_engine_config
()
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
executor_class
=
cls
.
_get_executor_cls
(
engine_config
)
if
executor_class
.
uses_ray
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
# Create the async LLM engine.
# Create the async LLM engine.
engine
=
cls
(
engine
=
cls
(
executor_class
.
uses_ray
,
**
engine_config
.
to_dict
(),
**
engine_config
.
to_dict
(),
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_requests
=
not
engine_args
.
disable_log_requests
,
log_requests
=
not
engine_args
.
disable_log_requests
,
...
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
...
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
self
.
_request_tracker
=
RequestTracker
()
self
.
_request_tracker
=
RequestTracker
()
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
self
.
_background_loop_unshielded
=
asyncio
.
get_event_loop
(
).
create_task
(
self
.
run_engine_loop
())
).
create_task
(
self
.
run_engine_loop
(
weakref
.
ref
(
self
)
))
self
.
_background_loop_unshielded
.
add_done_callback
(
self
.
_background_loop_unshielded
.
add_done_callback
(
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
...
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
...
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
async
def
_engine_abort
(
self
,
request_ids
:
Iterable
[
str
]):
self
.
engine
.
abort_request
(
request_ids
)
self
.
engine
.
abort_request
(
request_ids
)
async
def
run_engine_loop
(
self
):
@
staticmethod
async
def
run_engine_loop
(
engine_ref
:
ReferenceType
):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine
:
Optional
[
"AsyncLLMEngine"
]
=
engine_ref
()
if
not
engine
:
return
pipeline_parallel_size
=
\
pipeline_parallel_size
=
\
self
.
engine
.
parallel_config
.
pipeline_parallel_size
engine
.
engine
.
parallel_config
.
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
has_requests_in_progress
=
[
False
]
*
pipeline_parallel_size
while
True
:
while
True
:
if
not
any
(
has_requests_in_progress
):
if
not
any
(
has_requests_in_progress
):
...
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
...
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
# such as add/remove lora adapters.
await
self
.
engine
.
stop_remote_worker_execution_loop_async
()
await
engine
.
engine
.
stop_remote_worker_execution_loop_async
()
await
self
.
_request_tracker
.
wait_for_new_requests
()
request_tracker
=
engine
.
_request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del
engine
await
asyncio
.
sleep
(
0
)
if
engine_ref
()
is
None
:
return
await
request_tracker
.
wait_for_new_requests
()
engine
=
engine_ref
()
if
not
engine
:
return
logger
.
debug
(
"Got new requests!"
)
logger
.
debug
(
"Got new requests!"
)
requests_in_progress
=
[
requests_in_progress
=
[
asyncio
.
create_task
(
self
.
engine_step
(
ve
))
asyncio
.
create_task
(
engine
.
engine_step
(
ve
))
for
ve
in
range
(
pipeline_parallel_size
)
for
ve
in
range
(
pipeline_parallel_size
)
]
]
has_requests_in_progress
=
[
True
]
*
pipeline_parallel_size
has_requests_in_progress
=
[
True
]
*
pipeline_parallel_size
...
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
...
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
result
=
task
.
result
()
result
=
task
.
result
()
virtual_engine
=
requests_in_progress
.
index
(
task
)
virtual_engine
=
requests_in_progress
.
index
(
task
)
has_unfinished_requests
=
(
has_unfinished_requests
=
(
self
.
engine
.
has_unfinished_requests_for_virtual_engine
(
engine
.
engine
.
has_unfinished_requests_for_virtual_engine
(
virtual_engine
))
virtual_engine
))
if
result
or
has_unfinished_requests
:
if
result
or
has_unfinished_requests
:
requests_in_progress
[
virtual_engine
]
=
(
requests_in_progress
[
virtual_engine
]
=
(
asyncio
.
create_task
(
asyncio
.
create_task
(
self
.
engine_step
(
virtual_engine
)))
engine
.
engine_step
(
virtual_engine
)))
has_requests_in_progress
[
virtual_engine
]
=
True
has_requests_in_progress
[
virtual_engine
]
=
True
else
:
else
:
has_requests_in_progress
[
virtual_engine
]
=
False
has_requests_in_progress
[
virtual_engine
]
=
False
except
asyncio
.
TimeoutError
as
exc
:
except
asyncio
.
TimeoutError
as
exc
:
logger
.
error
(
logger
.
error
(
"Engine iteration timed out. This should never happen!"
)
"Engine iteration timed out. This should never happen!"
)
self
.
set_errored
(
exc
)
engine
.
set_errored
(
exc
)
raise
raise
await
asyncio
.
sleep
(
0
)
await
asyncio
.
sleep
(
0
)
...
...
vllm/engine/llm_engine.py
View file @
acd5511b
import
functools
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
usage_message
)
from
vllm.utils
import
Counter
,
Device
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -382,11 +382,16 @@ class LLMEngine:
...
@@ -382,11 +382,16 @@ class LLMEngine:
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
self
.
async_callbacks
=
[
if
model_config
.
use_async_output_proc
:
functools
.
partial
(
self
.
_process_model_outputs
,
process_model_outputs
=
weak_bind
(
self
.
_process_model_outputs
)
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
self
.
async_callbacks
=
[
]
partial
(
process_model_outputs
,
ctx
=
self
.
scheduler_contexts
[
v_id
])
for
v_id
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
else
:
self
.
async_callbacks
=
[]
# Currently used by AsyncLLMEngine to ensure quick append
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
# of request outputs to asyncio queues
...
@@ -869,8 +874,8 @@ class LLMEngine:
...
@@ -869,8 +874,8 @@ class LLMEngine:
"""
"""
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
return
self
.
scheduler
[
virtual_engine
].
has_unfinished_seqs
()
@
staticmethod
def
_process_sequence_group_outputs
(
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
outputs
:
List
[
EmbeddingSequenceGroupOutput
],
)
->
None
:
)
->
None
:
...
...
vllm/entrypoints/launcher.py
View file @
acd5511b
import
asyncio
import
asyncio
import
signal
import
signal
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Any
from
typing
import
Any
,
Optional
import
uvicorn
import
uvicorn
from
fastapi
import
FastAPI
,
Response
from
fastapi
import
FastAPI
,
Request
,
Response
from
vllm
import
envs
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
from
vllm.utils
import
find_process_using_port
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
engine
:
AsyncEngineClie
nt
,
async
def
serve_http
(
app
:
FastAPI
,
limit_concurrency
:
Optional
[
i
nt
]
,
**
uvicorn_kwargs
:
Any
):
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
for
route
in
app
.
routes
:
...
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
...
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
# Set concurrency limits in uvicorn if running in multiprocessing mode
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if
engine
.
limit_concurrency
is
not
None
:
if
limit_concurrency
is
not
None
:
logger
.
info
(
logger
.
info
(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing"
,
engine
.
limit_concurrency
)
"--disable-frontend-multiprocessing"
,
limit_concurrency
)
uvicorn_kwargs
[
"limit_concurrency"
]
=
engine
.
limit_concurrency
uvicorn_kwargs
[
"limit_concurrency"
]
=
limit_concurrency
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
_add_shutdown_handlers
(
app
,
server
)
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
...
@@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
...
@@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
return
server
.
shutdown
()
return
server
.
shutdown
()
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
,
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
)
->
None
:
engine
:
AsyncEngineClient
)
->
None
:
"""Adds handlers for fatal errors that should crash the server"""
"""Adds handlers for fatal errors that should crash the server"""
@
app
.
exception_handler
(
RuntimeError
)
@
app
.
exception_handler
(
RuntimeError
)
async
def
runtime_error_handler
(
_
,
__
):
async
def
runtime_error_handler
(
request
:
Request
,
__
):
"""On generic runtime error, check to see if the engine has died.
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine
=
request
.
app
.
state
.
engine_client
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
and
not
engine
.
is_running
):
and
not
engine
.
is_running
):
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
...
...
vllm/entrypoints/openai/api_server.py
View file @
acd5511b
...
@@ -4,16 +4,20 @@ import inspect
...
@@ -4,16 +4,20 @@ import inspect
import
multiprocessing
import
multiprocessing
import
os
import
os
import
re
import
re
import
signal
import
tempfile
import
tempfile
from
argparse
import
Namespace
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
functools
import
partial
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
AsyncIterator
,
Optional
,
Set
from
typing
import
AsyncIterator
,
Optional
,
Set
import
uvloop
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
starlette.datastructures
import
State
from
starlette.routing
import
Mount
from
starlette.routing
import
Mount
from
typing_extensions
import
assert_never
from
typing_extensions
import
assert_never
...
@@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION
...
@@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
TIMEOUT_KEEP_ALIVE
=
5
# seconds
async_engine_client
:
AsyncEngineClient
engine_args
:
AsyncEngineArgs
openai_serving_chat
:
OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
...
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,
...
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,
@
asynccontextmanager
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
async
def
lifespan
(
app
:
FastAPI
):
try
:
async
def
_force_log
():
if
app
.
state
.
log_stats
:
while
True
:
async_engine_client
=
app
.
state
.
engine_client
await
asyncio
.
sleep
(
10
)
await
async_engine_client
.
do_log_stats
()
async
def
_force_log
():
while
True
:
if
not
engine_args
.
disable_log_stats
:
await
asyncio
.
sleep
(
10
)
task
=
asyncio
.
create_task
(
_force_log
())
await
async_engine_client
.
do_log_stats
()
_running_tasks
.
add
(
task
)
task
.
add_done_callback
(
_running_tasks
.
remove
)
task
=
asyncio
.
create_task
(
_force_log
())
_running_tasks
.
add
(
task
)
yield
task
.
add_done_callback
(
_running_tasks
.
remove
)
else
:
task
=
None
try
:
yield
finally
:
if
task
is
not
None
:
task
.
cancel
()
finally
:
# Ensure app state including engine ref is gc'd
del
app
.
state
@
asynccontextmanager
@
asynccontextmanager
...
@@ -103,16 +111,10 @@ async def build_async_engine_client(
...
@@ -103,16 +111,10 @@ async def build_async_engine_client(
# Context manager to handle async_engine_client lifecycle
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
# Ensures everything is shutdown and cleaned up on error/exit
global
engine_args
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
# Backend itself still global for the silly lil' health handler
global
async_engine_client
async
with
build_async_engine_client_from_engine_args
(
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
async_engine_client
=
engine
# type: ignore[assignment]
yield
engine
yield
engine
...
@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
...
@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
if
(
model_is_embedding
(
engine_args
.
model
,
engine_args
.
trust_remote_code
,
if
(
model_is_embedding
(
engine_args
.
model
,
engine_args
.
trust_remote_code
,
engine_args
.
quantization
,
engine_args
.
revision
)
engine_args
.
quantization
,
engine_args
.
revision
)
or
disable_frontend_multiprocessing
):
or
disable_frontend_multiprocessing
):
engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_config
=
engine_args
.
create_engine_config
()
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
uses_ray
=
getattr
(
AsyncLLMEngine
.
_get_executor_cls
(
engine_config
),
try
:
"uses_ray"
,
False
)
yield
engine_client
finally
:
build_engine
=
partial
(
AsyncLLMEngine
.
from_engine_args
,
engine_client
.
shutdown_background_loop
()
engine_args
=
engine_args
,
engine_config
=
engine_config
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
if
uses_ray
:
# Must run in main thread with ray for its signal handlers to work
engine_client
=
build_engine
()
else
:
engine_client
=
await
asyncio
.
get_running_loop
().
run_in_executor
(
None
,
build_engine
)
yield
engine_client
return
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
# Otherwise, use the multiprocessing AsyncLLMEngine.
...
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
...
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
app
.
routes
.
append
(
metrics_route
)
app
.
routes
.
append
(
metrics_route
)
def
chat
(
request
:
Request
)
->
OpenAIServingChat
:
return
request
.
app
.
state
.
openai_serving_chat
def
completion
(
request
:
Request
)
->
OpenAIServingCompletion
:
return
request
.
app
.
state
.
openai_serving_completion
def
tokenization
(
request
:
Request
)
->
OpenAIServingTokenization
:
return
request
.
app
.
state
.
openai_serving_tokenization
def
embedding
(
request
:
Request
)
->
OpenAIServingEmbedding
:
return
request
.
app
.
state
.
openai_serving_embedding
def
engine_client
(
request
:
Request
)
->
AsyncEngineClient
:
return
request
.
app
.
state
.
engine_client
@
router
.
get
(
"/health"
)
@
router
.
get
(
"/health"
)
async
def
health
()
->
Response
:
async
def
health
(
raw_request
:
Request
)
->
Response
:
"""Health check."""
"""Health check."""
await
async_
engine_client
.
check_health
()
await
engine_client
(
raw_request
)
.
check_health
()
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/tokenize"
)
@
router
.
post
(
"/tokenize"
)
async
def
tokenize
(
request
:
TokenizeRequest
):
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_tokenization
.
create_tokenize
(
request
)
generator
=
await
tokenization
(
raw_request
)
.
create_tokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):
...
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):
@
router
.
post
(
"/detokenize"
)
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
async
def
detokenize
(
request
:
DetokenizeRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_tokenization
.
create_detokenize
(
request
)
generator
=
await
tokenization
(
raw_request
)
.
create_detokenize
(
request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
status_code
=
generator
.
code
)
...
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):
...
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):
@
router
.
get
(
"/v1/models"
)
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
():
async
def
show_available_models
(
raw_request
:
Request
):
models
=
await
openai_serving_completion
.
show_available_models
()
models
=
await
completion
(
raw_request
)
.
show_available_models
()
return
JSONResponse
(
content
=
models
.
model_dump
())
return
JSONResponse
(
content
=
models
.
model_dump
())
...
@@ -288,7 +320,7 @@ async def show_version():
...
@@ -288,7 +320,7 @@ async def show_version():
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
async
def
create_chat_completion
(
request
:
ChatCompletionRequest
,
raw_request
:
Request
):
raw_request
:
Request
):
generator
=
await
openai_serving_chat
.
create_chat_completion
(
generator
=
await
chat
(
raw_request
)
.
create_chat_completion
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
...
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
...
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@
router
.
post
(
"/v1/completions"
)
@
router
.
post
(
"/v1/completions"
)
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_completion
.
create_completion
(
generator
=
await
completion
(
raw_request
)
.
create_completion
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
...
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@
router
.
post
(
"/v1/embeddings"
)
@
router
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
generator
=
await
openai_serving_embedding
.
create_embedding
(
generator
=
await
embedding
(
raw_request
)
.
create_embedding
(
request
,
raw_request
)
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
return
JSONResponse
(
content
=
generator
.
model_dump
(),
...
@@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
...
@@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
"used for local development!"
)
"used for local development!"
)
@
router
.
post
(
"/start_profile"
)
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
():
async
def
start_profile
(
raw_request
:
Request
):
logger
.
info
(
"Starting profiler..."
)
logger
.
info
(
"Starting profiler..."
)
await
async_
engine_client
.
start_profile
()
await
engine_client
(
raw_request
)
.
start_profile
()
logger
.
info
(
"Profiler started."
)
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
():
async
def
stop_profile
(
raw_request
:
Request
):
logger
.
info
(
"Stopping profiler..."
)
logger
.
info
(
"Stopping profiler..."
)
await
async_
engine_client
.
stop_profile
()
await
engine_client
(
raw_request
)
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
return
Response
(
status_code
=
200
)
...
@@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"This should ONLY be used for local development!"
)
"This should ONLY be used for local development!"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
@
router
.
post
(
"/v1/load_lora_adapter"
)
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
):
async
def
load_lora_adapter
(
request
:
LoadLoraAdapterRequest
,
response
=
await
openai_serving_chat
.
load_lora_adapter
(
request
)
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
load_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
load_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
...
@@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
...
@@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return
Response
(
status_code
=
200
,
content
=
response
)
return
Response
(
status_code
=
200
,
content
=
response
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
@
router
.
post
(
"/v1/unload_lora_adapter"
)
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
):
async
def
unload_lora_adapter
(
request
:
UnloadLoraAdapterRequest
,
response
=
await
openai_serving_chat
.
unload_lora_adapter
(
request
)
raw_request
:
Request
):
response
=
await
chat
(
raw_request
).
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
response
=
await
openai_serving_completion
.
unload_lora_adapter
(
request
)
response
=
await
completion
(
raw_request
)
.
unload_lora_adapter
(
request
)
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
code
)
status_code
=
response
.
code
)
...
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:
@
app
.
exception_handler
(
RequestValidationError
)
@
app
.
exception_handler
(
RequestValidationError
)
async
def
validation_exception_handler
(
_
,
exc
):
async
def
validation_exception_handler
(
_
,
exc
):
err
=
openai_serving_chat
.
create_error_response
(
message
=
str
(
exc
))
chat
=
app
.
state
.
openai_serving_chat
err
=
chat
.
create_error_response
(
message
=
str
(
exc
))
return
JSONResponse
(
err
.
model_dump
(),
return
JSONResponse
(
err
.
model_dump
(),
status_code
=
HTTPStatus
.
BAD_REQUEST
)
status_code
=
HTTPStatus
.
BAD_REQUEST
)
...
@@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
...
@@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
return
app
return
app
async
def
init_app
(
def
init_app
_state
(
async_engine_client
:
AsyncEngineClient
,
async_engine_client
:
AsyncEngineClient
,
model_config
:
ModelConfig
,
state
:
State
,
args
:
Namespace
,
args
:
Namespace
,
)
->
FastAPI
:
)
->
None
:
app
=
build_app
(
args
)
if
args
.
served_model_name
is
not
None
:
if
args
.
served_model_name
is
not
None
:
served_model_names
=
args
.
served_model_name
served_model_names
=
args
.
served_model_name
else
:
else
:
served_model_names
=
[
args
.
model
]
served_model_names
=
[
args
.
model
]
model_config
=
await
async_engine_client
.
get_model_config
()
if
args
.
disable_log_requests
:
if
args
.
disable_log_requests
:
request_logger
=
None
request_logger
=
None
else
:
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
global
openai_serving_chat
state
.
engine_client
=
async_engine_client
global
openai_serving_completion
state
.
log_stats
=
not
args
.
disable_log_stats
global
openai_serving_embedding
global
openai_serving_tokenization
openai_serving_chat
=
OpenAIServingChat
(
state
.
openai_serving_chat
=
OpenAIServingChat
(
async_engine_client
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
...
@@ -465,7 +496,7 @@ async def init_app(
...
@@ -465,7 +496,7 @@ async def init_app(
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
enable_auto_tools
=
args
.
enable_auto_tool_choice
,
tool_parser
=
args
.
tool_call_parser
)
tool_parser
=
args
.
tool_call_parser
)
openai_serving_completion
=
OpenAIServingCompletion
(
state
.
openai_serving_completion
=
OpenAIServingCompletion
(
async_engine_client
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
...
@@ -474,13 +505,13 @@ async def init_app(
...
@@ -474,13 +505,13 @@ async def init_app(
request_logger
=
request_logger
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
return_tokens_as_token_ids
=
args
.
return_tokens_as_token_ids
,
)
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
state
.
openai_serving_embedding
=
OpenAIServingEmbedding
(
async_engine_client
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
)
)
openai_serving_tokenization
=
OpenAIServingTokenization
(
state
.
openai_serving_tokenization
=
OpenAIServingTokenization
(
async_engine_client
,
async_engine_client
,
model_config
,
model_config
,
served_model_names
,
served_model_names
,
...
@@ -488,25 +519,31 @@ async def init_app(
...
@@ -488,25 +519,31 @@ async def init_app(
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
args
.
chat_template
,
chat_template
=
args
.
chat_template
,
)
)
app
.
root_path
=
args
.
root_path
return
app
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
logger
.
info
(
"args: %s"
,
args
)
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
raise
KeyboardInterrupt
(
"terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
# If None, creation of the client failed and we exit.
# If None, creation of the client failed and we exit.
if
async_engine_client
is
None
:
if
async_engine_client
is
None
:
return
return
app
=
await
init_app
(
async_engine_client
,
args
)
app
=
build_app
(
args
)
model_config
=
await
async_engine_client
.
get_model_config
()
init_app_state
(
async_engine_client
,
model_config
,
app
.
state
,
args
)
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
engine
=
async_engine_client
,
limit_concurrency
=
async_engine_client
.
limit_concurrency
,
host
=
args
.
host
,
host
=
args
.
host
,
port
=
args
.
port
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
log_level
=
args
.
uvicorn_log_level
,
...
@@ -530,4 +567,4 @@ if __name__ == "__main__":
...
@@ -530,4 +567,4 @@ if __name__ == "__main__":
parser
=
make_arg_parser
(
parser
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
asyncio
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/rpc/server.py
View file @
acd5511b
...
@@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
...
@@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
"""Cleanup all resources."""
"""Cleanup all resources."""
self
.
socket
.
close
()
self
.
socket
.
close
()
self
.
context
.
destroy
()
self
.
context
.
destroy
()
self
.
engine
.
shutdown_background_loop
()
# Clear the engine reference so that it can be GC'ed.
# Clear the engine reference so that it can be GC'ed.
del
self
.
engine
del
self
.
engine
...
@@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):
...
@@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):
def
run_rpc_server
(
async_engine_args
:
AsyncEngineArgs
,
def
run_rpc_server
(
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
rpc_path
:
str
):
usage_context
:
UsageContext
,
rpc_path
:
str
):
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
raise
KeyboardInterrupt
(
"AsyncEngineRPCServer terminated"
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
rpc_path
)
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
rpc_path
)
uvloop
.
run
(
run_server
(
server
))
uvloop
.
run
(
run_server
(
server
))
vllm/executor/multiproc_gpu_executor.py
View file @
acd5511b
import
asyncio
import
asyncio
import
os
import
os
import
signal
import
threading
import
weakref
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
...
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
...
@@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref
=
weakref
.
ref
(
self
)
def
shutdown
(
signum
,
frame
):
if
executor
:
=
ref
():
executor
.
shutdown
()
if
threading
.
current_thread
()
is
threading
.
main_thread
():
signal
.
signal
(
signal
.
SIGINT
,
shutdown
)
signal
.
signal
(
signal
.
SIGTERM
,
shutdown
)
self
.
driver_worker
=
self
.
_create_worker
(
self
.
driver_worker
=
self
.
_create_worker
(
distributed_init_method
=
distributed_init_method
)
distributed_init_method
=
distributed_init_method
)
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"init_device"
)
...
...
vllm/executor/multiproc_worker_utils.py
View file @
acd5511b
...
@@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
...
@@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
logger
.
error
(
"Worker %s pid %s died, exit code: %s"
,
process
.
name
,
process
.
pid
,
process
.
exitcode
)
process
.
name
,
process
.
pid
,
process
.
exitcode
)
# Cleanup any remaining workers
# Cleanup any remaining workers
logger
.
info
(
"Killing local vLLM worker processes"
)
if
logger
:
logger
.
info
(
"Killing local vLLM worker processes"
)
for
worker
in
self
.
workers
:
for
worker
in
self
.
workers
:
worker
.
kill_worker
()
worker
.
kill_worker
()
# Must be done after worker task queues are all closed
# Must be done after worker task queues are all closed
...
@@ -221,6 +222,8 @@ def _run_worker_process(
...
@@ -221,6 +222,8 @@ def _run_worker_process(
try
:
try
:
executor
=
getattr
(
worker
,
method
)
executor
=
getattr
(
worker
,
method
)
output
=
executor
(
*
args
,
**
kwargs
)
output
=
executor
(
*
args
,
**
kwargs
)
except
KeyboardInterrupt
:
break
except
BaseException
as
e
:
except
BaseException
as
e
:
tb
=
traceback
.
format_exc
()
tb
=
traceback
.
format_exc
()
logger
.
error
(
logger
.
error
(
...
...
vllm/executor/ray_tpu_executor.py
View file @
acd5511b
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
...
@@ -26,6 +26,8 @@ logger = init_logger(__name__)
class
RayTPUExecutor
(
TPUExecutor
):
class
RayTPUExecutor
(
TPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# This is non-None when the execute model loop is running
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...
...
vllm/scripts.py
View file @
acd5511b
# The CLI entrypoint to vLLM.
# The CLI entrypoint to vLLM.
import
argparse
import
argparse
import
asyncio
import
os
import
os
import
signal
import
signal
import
sys
import
sys
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
uvloop
from
openai
import
OpenAI
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
from
openai.types.chat
import
ChatCompletionMessageParam
...
@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
...
@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
# EngineArgs expects the model name to be passed as --model.
args
.
model
=
args
.
model_tag
args
.
model
=
args
.
model_tag
asyncio
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
def
interactive_cli
(
args
:
argparse
.
Namespace
)
->
None
:
def
interactive_cli
(
args
:
argparse
.
Namespace
)
->
None
:
...
...
vllm/utils.py
View file @
acd5511b
...
@@ -12,6 +12,7 @@ import tempfile
...
@@ -12,6 +12,7 @@ import tempfile
import
threading
import
threading
import
uuid
import
uuid
import
warnings
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
...
@@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
...
@@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
return
_cuda_device_count_stateless
(
envs
.
CUDA_VISIBLE_DEVICES
)
return
_cuda_device_count_stateless
(
envs
.
CUDA_VISIBLE_DEVICES
)
def
weak_bind
(
bound_method
:
Callable
[...,
Any
],
)
->
Callable
[...,
None
]:
"""Make an instance method that weakly references
its associated instance and no-ops once that
instance is collected."""
ref
=
weakref
.
ref
(
bound_method
.
__self__
)
# type: ignore[attr-defined]
unbound
=
bound_method
.
__func__
# type: ignore[attr-defined]
def
weak_bound
(
*
args
,
**
kwargs
)
->
None
:
if
inst
:
=
ref
():
unbound
(
inst
,
*
args
,
**
kwargs
)
return
weak_bound
#From: https://stackoverflow.com/a/4104188/2749989
#From: https://stackoverflow.com/a/4104188/2749989
def
run_once
(
f
:
Callable
[
P
,
None
])
->
Callable
[
P
,
None
]:
def
run_once
(
f
:
Callable
[
P
,
None
])
->
Callable
[
P
,
None
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment