Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
"vscode:/vscode.git/clone" did not exist on "3e051bda82efe351ecfb5bb21de1606accc976d4"
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1019 additions
and
92 deletions
+1019
-92
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/entrypoints/test_api_server_process_manager.py
tests/entrypoints/test_api_server_process_manager.py
+268
-0
tests/utils.py
tests/utils.py
+3
-2
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+0
-1
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+0
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+5
-4
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+7
-7
tests/v1/entrypoints/openai/test_multi_api_servers.py
tests/v1/entrypoints/openai/test_multi_api_servers.py
+171
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+2
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+3
-3
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+0
-1
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+170
-9
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+63
-39
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+5
-0
vllm/utils.py
vllm/utils.py
+5
-1
vllm/v1/core/sched/interface.py
vllm/v1/core/sched/interface.py
+8
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+39
-14
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+4
-4
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+12
-2
vllm/v1/engine/coordinator.py
vllm/v1/engine/coordinator.py
+252
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
2dbe8c07
...
...
@@ -618,9 +618,11 @@ steps:
-
vllm/worker/model_runner.py
-
entrypoints/llm/test_collective_rpc.py
-
tests/v1/test_async_llm_dp.py
-
tests/v1/entrypoints/openai/test_multi_api_servers.py
-
vllm/v1/engine/
commands
:
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
...
...
tests/entrypoints/test_api_server_process_manager.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
socket
import
threading
import
time
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
from
vllm.v1.utils
import
(
APIServerProcessManager
,
wait_for_completion_or_failure
)
# Global variables to control worker behavior
# Number of seconds the mock worker sleeps before returning. Tests rebind
# this with ``global`` before creating their processes.
# NOTE(review): a rebinding made in the test process is presumably only
# seen by children created with the "fork" start method; a spawned child
# re-imports this module and gets the default below -- confirm which start
# method APIServerProcessManager uses.
WORKER_RUNTIME_SECONDS = 0.5
# Stand-in for the real API server worker entry point
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
    """Simulate an API server worker by sleeping for a fixed duration.

    ``listen_address``, ``sock`` and ``args`` are accepted only to match the
    worker call signature and are not used. ``client_config`` is echoed to
    stdout. The function sleeps for ``WORKER_RUNTIME_SECONDS`` and returns.
    """
    start_message = f"Mock worker started with client_config: {client_config}"
    print(start_message)

    # The "work" of the mock is simply to stay alive for the configured time.
    time.sleep(WORKER_RUNTIME_SECONDS)

    done_message = "Mock worker completed successfully"
    print(done_message)
@pytest.fixture
def api_server_args():
    """Fixture to provide arguments for APIServerProcessManager.

    Yields a dict of keyword arguments (three mock servers, matching lists
    of input/output ZMQ addresses and a stats update address). The
    placeholder socket created here is closed when the test finishes so it
    does not leak a file descriptor between tests.
    """
    sock = socket.socket()
    try:
        yield {
            "target_server_fn": mock_run_api_server_worker,
            "listen_address": "localhost:8000",
            "sock": sock,
            "args": "test_args",  # Simple string to avoid pickling issues
            "num_servers": 3,
            "input_addresses": [
                "tcp://127.0.0.1:5001",
                "tcp://127.0.0.1:5002",
                "tcp://127.0.0.1:5003",
            ],
            "output_addresses": [
                "tcp://127.0.0.1:6001",
                "tcp://127.0.0.1:6002",
                "tcp://127.0.0.1:6003",
            ],
            "stats_update_address": "tcp://127.0.0.1:7000",
        }
    finally:
        # Closing an already-closed socket is a no-op, so this is safe even
        # if the code under test closed it first.
        sock.close()
@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
    """Check that the manager starts three live processes and stops them.

    Runs once with and once without the ``stats_update_address`` argument.
    """
    # Keep the mock workers short-lived so the test finishes quickly.
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.5

    # Work on a private copy so the fixture's dict is never mutated.
    manager_kwargs = dict(api_server_args)
    if not with_stats_update:
        del manager_kwargs["stats_update_address"]

    manager = APIServerProcessManager(**manager_kwargs)

    try:
        # One process per requested server.
        assert len(manager.processes) == 3

        # Every process must be running right after construction...
        for worker in manager.processes:
            assert worker.is_alive()

        print("Waiting for processes to run...")
        time.sleep(WORKER_RUNTIME_SECONDS / 2)

        # ...and still running halfway through the worker runtime.
        for worker in manager.processes:
            assert worker.is_alive()
    finally:
        # Cleanup runs regardless of the assertions above.
        print("Cleaning up processes...")
        manager.close()

        # Short grace period for the processes to exit.
        time.sleep(0.2)

        # After close() no process may remain alive.
        for worker in manager.processes:
            assert not worker.is_alive()
# Patch the worker entrypoint under the name serve.py actually defines
# (``run_api_server_worker_proc``); patching a non-existent attribute makes
# ``unittest.mock.patch`` raise AttributeError before the test body runs.
@patch("vllm.entrypoints.cli.serve.run_api_server_worker_proc",
       mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
    """Test that wait_for_completion_or_failure works with failures.

    Starts three mock workers, runs ``wait_for_completion_or_failure`` in a
    background thread, terminates one worker and checks that the wait call
    raises a "died with exit code" error and that no worker is left alive.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 1.0

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            # Exceptions raised in a thread are not propagated to the
            # caller, so stash them for the assertions below.
            try:
                wait_for_completion_or_failure(api_server_manager=manager)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Let all processes run for a short time
        time.sleep(0.2)

        # All processes should still be running
        assert all(proc.is_alive() for proc in manager.processes)

        # Now simulate a process failure
        print("Simulating process failure...")
        manager.processes[0].terminate()

        # Wait for the wait_for_completion_or_failure
        # to detect and handle the failure
        # This should trigger it to terminate all other processes
        wait_thread.join(timeout=1.0)

        # The wait thread should have exited
        assert not wait_thread.is_alive()

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None
        assert "died with exit code" in str(result["exception"])

        # All processes should now be terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(), f"Process {i} should not be alive"

    finally:
        manager.close()
        time.sleep(0.2)
@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
    """Test that wait_for_completion_or_failure works in normal completion.

    Lets the short-lived mock workers finish by themselves, then checks that
    ``wait_for_completion_or_failure`` returns without raising.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.1

    # Create the manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Poll until every process has exited. The list is rebuilt on each
        # pass rather than calling remove() during iteration, which would
        # skip the element following each removed one.
        remaining_processes = manager.processes.copy()
        while remaining_processes:
            remaining_processes = [
                proc for proc in remaining_processes if proc.is_alive()
            ]
            if remaining_processes:
                time.sleep(0.1)

        # Verify all processes have terminated
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(
            ), f"Process {i} still alive after terminate()"

        # Now call wait_for_completion_or_failure
        # since all processes have already
        # terminated, it should return immediately
        # with no error
        wait_for_completion_or_failure(api_server_manager=manager)

    finally:
        # Clean up just in case
        manager.close()
        time.sleep(0.2)
@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
    """Test that wait_for_completion_or_failure handles additional processes.

    An extra process is wrapped in a minimal coordinator-like object and
    passed as ``coordinator``. Terminating that process must make the wait
    call raise an error naming it, and leave no API server process alive.
    """
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 100

    # Create and start the external process
    # (simulates local_engine_manager or coordinator)
    spawn_context = multiprocessing.get_context("spawn")
    external_proc = spawn_context.Process(
        target=mock_run_api_server_worker,
        name="MockExternalProcess",
        # The mock takes three required positional arguments; without them
        # the child would exit with a TypeError instead of simulating a
        # running process. ``None`` stands in for the socket so nothing
        # unpicklable is sent to the child.
        args=("localhost:8000", None, "test_args"))
    external_proc.start()

    # Create the class to simulate a coordinator
    class MockCoordinator:

        def __init__(self, proc):
            self.proc = proc

        def close(self):
            # Terminate the wrapped process if it is still running.
            if self.proc.is_alive():
                self.proc.terminate()
                self.proc.join(timeout=0.5)

    # Create a mock coordinator with the external process
    mock_coordinator = MockCoordinator(external_proc)

    # Create the API server manager
    manager = APIServerProcessManager(**api_server_args)

    try:
        # Verify manager initialization
        assert len(manager.processes) == 3

        # Create a result capture for the thread
        result: dict[str, Optional[Exception]] = {"exception": None}

        def run_with_exception_capture():
            # Exceptions raised in a thread are not propagated to the
            # caller, so stash them for the assertions below.
            try:
                wait_for_completion_or_failure(api_server_manager=manager,
                                               coordinator=mock_coordinator)
            except Exception as e:
                result["exception"] = e

        # Start a thread to run wait_for_completion_or_failure
        wait_thread = threading.Thread(target=run_with_exception_capture,
                                       daemon=True)
        wait_thread.start()

        # Terminate the external process to trigger a failure
        time.sleep(0.2)
        external_proc.terminate()

        # Wait for the thread to detect the failure
        wait_thread.join(timeout=1.0)

        # The wait thread should have completed
        assert not wait_thread.is_alive(
        ), "wait_for_completion_or_failure thread still running"

        # Verify that an exception was raised with appropriate error message
        assert result["exception"] is not None, "No exception was raised"
        error_message = str(result["exception"])
        assert "died with exit code" in error_message, \
            f"Unexpected error message: {error_message}"
        assert "MockExternalProcess" in error_message, \
            f"Error doesn't mention external process: {error_message}"

        # Verify that all API server processes were terminated as a result
        for i, proc in enumerate(manager.processes):
            assert not proc.is_alive(
            ), f"API server process {i} was not terminated"

    finally:
        # Clean up
        manager.close()
        mock_coordinator.close()
        time.sleep(0.2)
tests/utils.py
View file @
2dbe8c07
...
...
@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.
openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.
cli.serve
import
ServeSubcommand
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
...
@@ -99,7 +99,8 @@ class RemoteOpenAIServer:
parser
=
FlexibleArgumentParser
(
description
=
"vLLM's remote OpenAI server."
)
parser
=
make_arg_parser
(
parser
)
subparsers
=
parser
.
add_subparsers
(
required
=
False
,
dest
=
"subparser"
)
parser
=
ServeSubcommand
().
subparser_init
(
subparsers
)
args
=
parser
.
parse_args
([
"--model"
,
model
,
*
vllm_serve_args
])
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
port
=
int
(
args
.
port
)
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
2dbe8c07
...
...
@@ -45,7 +45,6 @@ def make_request(request_id,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
)
...
...
tests/v1/core/test_prefix_caching.py
View file @
2dbe8c07
...
...
@@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
)
...
...
tests/v1/core/test_scheduler.py
View file @
2dbe8c07
...
...
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
)
requests
.
append
(
request
)
return
requests
...
...
@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert
running_req
.
num_tokens_with_spec
==
2
+
len
(
spec_tokens
[
i
])
# No draft or accepted tokens counted yet
assert
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
is
None
assert
not
engine_core_outputs
or
(
engine_core_outputs
[
0
].
scheduler_stats
.
spec_decoding_stats
is
None
)
# Schedule the speculated tokens for validation
output
=
scheduler
.
schedule
()
...
...
@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
scheduler_stats
=
engine_core_outputs
.
scheduler_stats
scheduler_stats
=
engine_core_outputs
[
0
].
scheduler_stats
\
if
engine_core_outputs
else
None
if
expected
[
0
]
==
0
:
assert
scheduler_stats
.
spec_decoding_stats
is
None
else
:
...
...
@@ -843,7 +844,7 @@ def _step_until_done(
# We should be in the decode phase now.
assert
num_scheduled_tokens
==
1
assert
len
(
output
.
kv_connector_metadata
.
requests
)
==
0
ecos
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
ecos
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
[
0
]
all_done
=
True
for
eco
in
ecos
.
outputs
:
if
eco
.
finish_reason
is
None
:
...
...
tests/v1/engine/test_engine_core.py
View file @
2dbe8c07
...
...
@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert
len
(
engine_core
.
scheduler
.
running
)
==
4
# Loop through until they are all done.
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
...
...
@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0
.
request_id
=
req1
.
request_id
=
"test"
engine_core
.
add_request
(
req0
)
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
engine_core
.
add_request
(
req1
)
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
...
...
@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
# Loop through until they are all done.
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
...
...
@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
4
# Batch queue is full. Finish Batch 2. Get first token of req0.
output
=
engine_core
.
step_with_batch_queue
()[
0
]
output
=
engine_core
.
step_with_batch_queue
()[
0
]
.
get
(
0
)
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req0
.
request_id
].
num_tokens
==
13
...
...
@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
1
# Batch queue is full. Finish Batch 3. Get first token of req1.
output
=
engine_core
.
step_with_batch_queue
()[
0
]
output
=
engine_core
.
step_with_batch_queue
()[
0
]
.
get
(
0
)
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req1
.
request_id
].
num_tokens
==
13
...
...
@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
if
step
%
2
==
0
:
# Even steps consumes an output.
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
len
(
output
[
0
]
.
outputs
)
==
1
if
req_id
in
engine_core
.
scheduler
.
requests
:
assert
engine_core
.
scheduler
.
requests
[
req_id
].
num_tokens
==
expected_num_tokens
[
req_id
]
...
...
tests/v1/entrypoints/openai/test_multi_api_servers.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
os
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
# Model served by the module-scoped test server and requested by the tests.
MODEL_NAME = "ibm-research/PowerMoE-3b"

# Data parallel size, read from the DP_SIZE environment variable (default
# "1"). Kept as a string because it is passed straight through as the value
# of the --data_parallel_size command-line flag.
DP_SIZE = os.getenv("DP_SIZE", "1")
@pytest.fixture(scope="module")
def default_server_args():
    """Return the command-line flags used to launch the test server."""
    # use half precision for speed and memory savings in CI environment
    cli_flags = ["--dtype", "bfloat16"]
    cli_flags += ["--max-model-len", "2048"]
    cli_flags += ["--max-num-seqs", "128"]
    cli_flags += ["--enforce-eager"]
    # Four API server processes, with the data parallel size from DP_SIZE.
    cli_flags += ["--api-server-count", "4"]
    cli_flags += ["--data_parallel_size", DP_SIZE]
    return cli_flags
@pytest.fixture(scope="module")
def server(default_server_args):
    """Yield a RemoteOpenAIServer for MODEL_NAME, shared by the module.

    The server object is used as a context manager, so it is exited once
    the last test of the module has finished with it.
    """
    launcher = RemoteOpenAIServer(MODEL_NAME, default_server_args)
    with launcher as running_server:
        yield running_server
@pytest_asyncio.fixture
async def client(server):
    """Yield the async client returned by ``server.get_async_client()``."""
    client_ctx = server.get_async_client()
    async with client_ctx as openai_client:
        yield openai_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
                                 model_name: str) -> None:
    """Send one completion request, then two bursts of 100 concurrent ones.

    Every response is validated inside ``make_request``; the bursts only
    additionally check that each request produced a result.
    """

    async def make_request():
        """Issue one sampled completion and sanity-check the response."""
        response = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert response.id is not None
        assert response.choices is not None and len(response.choices) == 1

        first_choice = response.choices[0]
        # Sampling at temperature=1.0 is not deterministic, so only require
        # some generated text rather than an exact length.
        assert len(first_choice.text) >= 1
        # For the same reason, either "length" or "stop" is an acceptable
        # finish reason.
        assert first_choice.finish_reason in ("length", "stop")

        # Token counts vary as well; they only need to be positive.
        usage = response.usage
        assert usage.completion_tokens > 0
        assert usage.prompt_tokens > 0
        assert usage.total_tokens > 0
        return response

    # A single request first.
    single_result = await make_request()
    assert single_result is not None

    # Then two bursts of concurrent requests, each preceded by a short pause.
    num_requests = 100
    for _ in range(2):
        await asyncio.sleep(0.5)

        pending = [make_request() for _ in range(num_requests)]
        burst_results = await asyncio.gather(*pending)
        assert len(burst_results) == num_requests
        assert all(item is not None for item in burst_results)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str) -> None:
    """Check that streamed completions match their non-streamed output.

    Runs the comparison once on its own and then as two bursts of 100
    concurrent comparisons.
    """
    prompt = "What is an LLM?"

    async def make_streaming_request():
        """Compare one greedy streamed completion with a non-streamed one."""
        # Reference output: the same request without streaming.
        reference = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        expected_text = reference.choices[0].text

        # The streamed variant of the identical request.
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        pieces: list[str] = []
        num_finish_reasons = 0
        final_chunk = None
        async for chunk in stream:
            choice = chunk.choices[0]
            pieces.append(choice.text)
            if choice.finish_reason is not None:
                num_finish_reasons += 1
            # Remember the most recent chunk for the checks below.
            final_chunk = chunk

        # finish reason should only return in the last block for OpenAI API
        assert num_finish_reasons == 1, (
            "Finish reason should appear exactly once.")
        assert final_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert final_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # The concatenated chunks must equal the non-streamed text.
        assert "".join(
            pieces
        ) == expected_text, "Streamed output should match non-streamed output."
        # A truthy return value marks this request as successful.
        return True

    # A single request first.
    single_result = await make_streaming_request()
    assert single_result is not None

    # Then two bursts of concurrent requests, each preceded by a short pause.
    num_requests = 100
    for _ in range(2):
        await asyncio.sleep(0.5)

        pending = [make_streaming_request() for _ in range(num_requests)]
        results = await asyncio.gather(*pending)

        assert len(
            results
        ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
        assert all(
            results), "Not all streaming requests completed successfully."
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
2dbe8c07
...
...
@@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 tokens.
assert
request
.
is_finished
()
assert
request
.
status
==
RequestStatus
.
FINISHED_LENGTH_CAPPED
output
=
engine_core_outputs
.
outputs
[
0
]
output
=
engine_core_outputs
[
0
]
.
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
LENGTH
assert
output
.
kv_transfer_params
is
not
None
...
...
@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
kv_transfer_params
=
eco
.
outputs
[
0
].
kv_transfer_params
kv_transfer_params
=
eco
[
0
]
.
outputs
[
0
].
kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert
(
len
(
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
2dbe8c07
...
...
@@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output()
engine_core_outputs
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
assert
len
(
engine_core_outputs
.
outputs
)
==
0
assert
not
engine_core_outputs
or
not
engine_core_outputs
[
0
]
.
outputs
# STEP (2):
# (2a): schedule(): nothing happens!
...
...
@@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output
)
scheduler
.
schedule
()
outputs
=
engine_core_outputs
.
outputs
outputs
=
engine_core_outputs
[
0
]
.
outputs
assert
len
(
outputs
)
==
1
output
=
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
STOP
...
...
@@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output
)
scheduler
.
schedule
()
outputs
=
engine_core_outputs
.
outputs
outputs
=
engine_core_outputs
[
0
]
.
outputs
assert
len
(
outputs
)
==
1
output
=
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
STOP
...
...
tests/v1/kv_connector/unit/utils.py
View file @
2dbe8c07
...
...
@@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
)
req
.
kv_transfer_params
=
kv_transfer_params
return
req
...
...
vllm/entrypoints/cli/serve.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
os
import
signal
import
sys
import
uvloop
import
zmq
import
vllm.envs
as
envs
from
vllm
import
AsyncEngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
(
run_server
,
run_server_worker
,
setup_server
)
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.entrypoints.utils
import
(
VLLM_SERVE_PARSER_EPILOG
,
show_filtered_argument_or_group_from_help
)
from
vllm.executor.multiproc_worker_utils
import
_add_prefix
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
,
zmq_socket_ctx
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_completion_or_failure
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
...
...
@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
args
.
model
=
args
.
model_tag
if
args
.
headless
:
if
args
.
headless
or
args
.
api_server_count
<
1
:
run_headless
(
args
)
elif
args
.
api_server_count
>
1
:
run_multi_api_server
(
args
)
else
:
# Single API server (this process).
uvloop
.
run
(
run_server
(
args
))
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
...
...
@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
type
=
int
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
'--api-server-count'
,
'-asc'
,
type
=
int
,
default
=
1
,
help
=
'How many API server processes to run.'
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
...
...
@@ -91,23 +110,26 @@ def cmd_init() -> list[CLISubcommand]:
def
run_headless
(
args
:
argparse
.
Namespace
):
if
args
.
api_server_count
>
1
:
raise
ValueError
(
"api_server_count can't be set in headless mode"
)
# Create the EngineConfig.
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
if
not
envs
.
VLLM_USE_V1
:
raise
Runtim
eError
(
"Headless mode is only supported for V1"
)
raise
Valu
eError
(
"Headless mode is only supported for V1"
)
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
input
_address
=
get_tcp_uri
(
host
,
port
)
handshake
_address
=
get_tcp_uri
(
host
,
port
)
if
local_engine_count
<=
0
:
raise
Runtim
eError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
raise
Valu
eError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def
signal_handler
(
signum
,
frame
):
...
...
@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):
logger
.
info
(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s."
,
local_engine_count
,
input
_address
)
"with head node address %s."
,
local_engine_count
,
handshake
_address
)
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
...
...
@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
local_start_index
=
0
,
vllm_config
=
vllm_config
,
on_head_node
=
False
,
input_address
=
input
_address
,
handshake_address
=
handshake
_address
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
)
...
...
@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
finally
:
logger
.
info
(
"Shutting down."
)
engine_manager
.
close
()
def run_multi_api_server(args: argparse.Namespace):
    """Run ``args.api_server_count`` API server processes for one deployment.

    Steps, in order: set up the listening socket, build the engine config,
    allocate one ZMQ input and one output address per API server, start a
    ``DPCoordinator`` when ``data_parallel_size > 1``, start the local engine
    processes (if any) and the API server processes, wait for the engine
    startup handshake, and finally block in
    ``wait_for_completion_or_failure``.

    Args:
        args: Parsed ``serve`` arguments; must not be in headless mode and
            must have ``api_server_count > 0``.

    Raises:
        ValueError: If ``api_server_count > 1`` is combined with a non-V1
            engine or with ``VLLM_ALLOW_RUNTIME_LORA_UPDATING``.
    """

    assert not args.headless
    num_api_servers = args.api_server_count
    assert num_api_servers > 0

    if num_api_servers > 1:
        setup_multiprocess_prometheus()

    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    model_config = vllm_config.model_config

    if num_api_servers > 1:
        if not envs.VLLM_USE_V1:
            raise ValueError("api_server_count > 1 is only supported for V1")

        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
            raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
                             "with api_server_count > 1")

        # The multi-modal preprocessor cache is force-disabled (with a
        # warning) when several API servers are requested.
        if model_config.is_multimodal_model and not (
                model_config.disable_mm_preprocessor_cache):
            logger.warning(
                "Multi-modal preprocessor cache will be disabled for"
                " api_server_count > 1")
            model_config.disable_mm_preprocessor_cache = True

    parallel_config = vllm_config.parallel_config

    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # True when every data parallel engine is started by this process.
    local_only = local_engine_count == dp_size

    # Set up input and output addresses.
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses = EngineZmqAddresses(
        inputs=input_addresses,
        outputs=output_addresses,
    )

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        coordinator = DPCoordinator(parallel_config)
        addresses.coordinator_input, addresses.coordinator_output = (
            coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()
        logger.info("Started DP Coordinator process (PID: %d)",
                    coordinator.proc.pid)

    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    # The handshake socket is bound only for the duration of engine startup.
    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines.
        if not local_engine_count:
            local_engine_manager = None
        else:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)

        # Start API servers using the manager
        api_server_manager = APIServerProcessManager(
            target_server_fn=run_api_server_worker_proc,
            listen_address=listen_address,
            sock=sock,
            args=args,
            num_servers=num_api_servers,
            input_addresses=input_addresses,
            output_addresses=output_addresses,
            stats_update_address=stats_update_address)

        # Wait for engine handshakes to complete.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]
        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

    # Wait for API servers
    wait_for_completion_or_failure(api_server_manager=api_server_manager,
                                   local_engine_manager=local_engine_manager,
                                   coordinator=coordinator)
def run_api_server_worker_proc(listen_address,
                               sock,
                               args,
                               client_config=None,
                               **uvicorn_kwargs) -> None:
    """Process entrypoint that runs a single API server worker.

    Prefixes this process's stdout and stderr with the process name and PID,
    then runs ``run_server_worker`` under uvloop, forwarding all arguments
    unchanged.
    """
    from multiprocessing import current_process

    # Tag both output streams so output from different workers can be
    # told apart.
    worker_name = current_process().name
    worker_pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args, client_config,
                                    **uvicorn_kwargs)
    uvloop.run(server_coro)
vllm/entrypoints/openai/api_server.py
View file @
2dbe8c07
...
...
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from
functools
import
partial
from
http
import
HTTPStatus
from
json
import
JSONDecodeError
from
typing
import
Annotated
,
Optional
from
typing
import
Annotated
,
Any
,
Optional
import
prometheus_client
import
regex
as
re
...
...
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.datastructures
import
State
from
starlette.routing
import
Mount
...
...
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Device
,
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
@
asynccontextmanager
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
EngineClient
]:
args
:
Namespace
,
client_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncIterator
[
EngineClient
]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
engine_args
,
args
.
disable_frontend_multiprocessing
,
client_config
)
as
engine
:
yield
engine
...
...
@@ -157,6 +163,7 @@ async def build_async_engine_client(
async
def
build_async_engine_client_from_engine_args
(
engine_args
:
AsyncEngineArgs
,
disable_frontend_multiprocessing
:
bool
=
False
,
client_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncIterator
[
EngineClient
]:
"""
Create EngineClient, either:
...
...
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
from
vllm.v1.engine.async_llm
import
AsyncLLM
async_llm
:
Optional
[
AsyncLLM
]
=
None
client_index
=
client_config
.
pop
(
"client_index"
)
if
client_config
else
0
try
:
async_llm
=
AsyncLLM
.
from_vllm_config
(
vllm_config
=
vllm_config
,
usage_context
=
usage_context
,
disable_log_requests
=
engine_args
.
disable_log_requests
,
disable_log_stats
=
engine_args
.
disable_log_stats
)
disable_log_stats
=
engine_args
.
disable_log_stats
,
client_addresses
=
client_config
,
client_index
=
client_index
)
# Don't keep the dummy data in memory
await
async_llm
.
reset_mm_cache
()
...
...
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
def
mount_metrics
(
app
:
FastAPI
):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
(
REGISTRY
,
CollectorRegistry
,
make_asgi_app
,
multiprocess
)
from
prometheus_fastapi_instrumentator
import
Instrumentator
registry
=
REGISTRY
prometheus_multiproc_dir_path
=
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
,
None
)
if
prometheus_multiproc_dir_path
is
not
None
:
logger
.
debug
(
"vLLM to use %s as PROMETHEUS_MULTIPROC_DIR"
,
prometheus_multiproc_dir_path
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
"""Mount prometheus metrics to a FastAPI app."""
registry
=
get_prometheus_registry
()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
...
...
@@ -1256,16 +1254,10 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return
sock
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
log_non_default_args
(
args
)
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
def
validate_api_server_args
(
args
):
valid_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
if
args
.
enable_auto_tool_choice
\
and
args
.
tool_call_parser
not
in
valid_tool_parses
:
and
args
.
tool_call_parser
not
in
valid_tool_parses
:
raise
KeyError
(
f
"invalid tool call parser:
{
args
.
tool_call_parser
}
"
f
"(chose from {{
{
','
.
join
(
valid_tool_parses
)
}
}})"
)
...
...
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f
"invalid reasoning parser:
{
args
.
reasoning_parser
}
"
f
"(chose from {{
{
','
.
join
(
valid_reasoning_parses
)
}
}})"
)
def
setup_server
(
args
):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
log_non_default_args
(
args
)
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
validate_api_server_args
(
args
)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
...
...
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
addr
,
port
=
sock_addr
is_ssl
=
args
.
ssl_keyfile
and
args
.
ssl_certfile
host_part
=
f
"[
{
addr
}
]"
if
is_valid_ipv6_address
(
addr
)
else
addr
or
"0.0.0.0"
listen_address
=
f
"http
{
's'
if
is_ssl
else
''
}
://
{
host_part
}
:
{
port
}
"
return
listen_address
,
sock
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Performs the one-time setup via ``setup_server`` and then serves from
    this process by awaiting ``run_server_worker`` (no client config is
    passed, so the worker uses its defaults).

    Args:
        args: Parsed server arguments, passed to both ``setup_server`` and
            ``run_server_worker``.
        **uvicorn_kwargs: Extra keyword arguments forwarded unchanged to
            ``run_server_worker``.
    """
    # setup_server hands back the address string to advertise and the
    # already-created listening socket.
    bound_address, server_sock = setup_server(args)
    await run_server_worker(bound_address, server_sock, args,
                            **uvicorn_kwargs)
async
def
run_server_worker
(
listen_address
,
sock
,
args
,
client_config
=
None
,
**
uvicorn_kwargs
)
->
None
:
"""Run a single API server worker."""
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
server_index
=
client_config
.
get
(
"client_index"
,
0
)
if
client_config
else
0
async
with
build_async_engine_client
(
args
,
client_config
)
as
engine_client
:
app
=
build_app
(
args
)
vllm_config
=
await
engine_client
.
get_vllm_config
()
await
init_app_state
(
engine_client
,
vllm_config
,
app
.
state
,
args
)
def
_listen_addr
(
a
:
str
)
->
str
:
if
is_valid_ipv6_address
(
a
):
return
'['
+
a
+
']'
return
a
or
"0.0.0.0"
is_ssl
=
args
.
ssl_keyfile
and
args
.
ssl_certfile
logger
.
info
(
"Starting vLLM API server on http%s://%s:%d"
,
"s"
if
is_ssl
else
""
,
_listen_addr
(
sock_addr
[
0
]),
sock_addr
[
1
])
logger
.
info
(
"Starting vLLM API server %d on %s"
,
server_index
,
listen_address
)
shutdown_task
=
await
serve_http
(
app
,
sock
=
sock
,
...
...
vllm/lora/worker_manager.py
View file @
2dbe8c07
...
...
@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self
.
add_adapter
(
lora
)
def
add_adapter
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if
lora_request
.
lora_int_id
not
in
self
.
list_adapters
():
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.
...
...
vllm/utils.py
View file @
2dbe8c07
...
...
@@ -2420,6 +2420,7 @@ def make_zmq_socket(
socket_type
:
Any
,
bind
:
Optional
[
bool
]
=
None
,
identity
:
Optional
[
bytes
]
=
None
,
linger
:
Optional
[
int
]
=
None
,
)
->
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]:
# type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
...
...
@@ -2439,7 +2440,7 @@ def make_zmq_socket(
buf_size
=
-
1
# Use system default buffer size
if
bind
is
None
:
bind
=
socket_type
!=
zmq
.
PUSH
bind
=
socket_type
not
in
(
zmq
.
PUSH
,
zmq
.
SUB
,
zmq
.
XSUB
)
if
socket_type
in
(
zmq
.
PULL
,
zmq
.
DEALER
,
zmq
.
ROUTER
):
socket
.
setsockopt
(
zmq
.
RCVHWM
,
0
)
...
...
@@ -2452,6 +2453,9 @@ def make_zmq_socket(
if
identity
is
not
None
:
socket
.
setsockopt
(
zmq
.
IDENTITY
,
identity
)
if
linger
is
not
None
:
socket
.
setsockopt
(
zmq
.
LINGER
,
linger
)
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme
,
host
,
_
=
split_zmq_path
(
path
)
...
...
vllm/v1/core/sched/interface.py
View file @
2dbe8c07
...
...
@@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
self
,
scheduler_output
:
"SchedulerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
"EngineCoreOutputs"
:
)
->
dict
[
int
,
"EngineCoreOutputs"
]
:
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
...
...
@@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
raise
NotImplementedError
...
...
@@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
"""
raise
NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
    """Returns (num_running_reqs, num_waiting_reqs).

    Abstract: this base implementation always raises NotImplementedError,
    so concrete schedulers must override it.

    Returns:
        A 2-tuple of ints, running count first and waiting count second.
    """
    raise NotImplementedError
@
abstractmethod
def
make_stats
(
self
)
->
Optional
[
"SchedulerStats"
]:
"""Make a SchedulerStats object for logging.
...
...
vllm/v1/core/sched/scheduler.py
View file @
2dbe8c07
...
...
@@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
self
.
include_finished_set
=
include_finished_set
self
.
finished_req_ids_dict
:
Optional
[
dict
[
int
,
set
[
str
]]]
=
(
defaultdict
(
set
)
if
include_finished_set
else
None
)
# Scheduling constraints.
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
...
...
@@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
self
,
scheduler_output
:
SchedulerOutput
,
model_runner_output
:
ModelRunnerOutput
,
)
->
EngineCoreOutputs
:
)
->
dict
[
int
,
EngineCoreOutputs
]
:
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
spec_token_ids
=
model_runner_output
.
spec_token_ids
logprobs
=
model_runner_output
.
logprobs
...
...
@@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
new_running
:
list
[
Request
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]
]
=
defaultdict
(
list
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
...
...
@@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
if
new_token_ids
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
.
append
(
outputs
[
request
.
client_index
]
.
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
...
...
@@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
self
.
_cached_reqs_data
[
req_data
.
req_id
].
append
(
req_data
)
self
.
running
=
new_running
engine_core_outputs
=
EngineCoreOutputs
(
outputs
=
outputs
,
scheduler_stats
=
self
.
make_stats
(
spec_decoding_stats
),
)
if
self
.
include_finished_set
:
#TODO currently sending duplicates here, improve this
engine_core_outputs
.
finished_requests
=
(
scheduler_output
.
finished_req_ids
|
self
.
finished_req_ids
)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs
=
{
client_index
:
EngineCoreOutputs
(
outputs
=
outs
)
for
client_index
,
outs
in
outputs
.
items
()
}
finished_req_ids
=
self
.
finished_req_ids_dict
if
finished_req_ids
is
not
None
:
# Include ids of requests that finished since last outputs
# were sent.
for
client_index
,
finished_set
in
finished_req_ids
.
items
():
# Set finished request set in EngineCoreOutputs for this client.
if
(
eco
:
=
engine_core_outputs
.
get
(
client_index
))
is
not
None
:
eco
.
finished_requests
=
finished_set
else
:
engine_core_outputs
[
client_index
]
=
EngineCoreOutputs
(
finished_requests
=
finished_set
)
finished_req_ids
.
clear
()
if
engine_core_outputs
:
# Return stats to only one of the front-ends.
next
(
iter
(
engine_core_outputs
.
values
())).
scheduler_stats
=
(
self
.
make_stats
(
spec_decoding_stats
))
return
engine_core_outputs
def get_request_counts(self) -> tuple[int, int]:
    """Report the current sizes of the running and waiting queues.

    Returns:
        ``(num_running_reqs, num_waiting_reqs)``: the lengths of
        ``self.running`` and ``self.waiting``, in that order.
    """
    num_running_reqs = len(self.running)
    num_waiting_reqs = len(self.waiting)
    return num_running_reqs, num_waiting_reqs
def
add_request
(
self
,
request
:
Request
)
->
None
:
self
.
waiting
.
append
(
request
)
self
.
requests
[
request
.
request_id
]
=
request
...
...
@@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):
delay_free_blocks
,
kv_xfer_params
=
self
.
_connector_finished
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
self
.
_cached_reqs_data
.
pop
(
request
.
request_id
,
None
)
self
.
finished_req_ids
.
add
(
request
.
request_id
)
request_id
=
request
.
request_id
self
.
_cached_reqs_data
.
pop
(
request_id
,
None
)
self
.
finished_req_ids
.
add
(
request_id
)
if
self
.
finished_req_ids_dict
is
not
None
:
self
.
finished_req_ids_dict
[
request
.
client_index
].
add
(
request_id
)
if
not
delay_free_blocks
:
self
.
_free_blocks
(
request
)
...
...
vllm/v1/engine/__init__.py
View file @
2dbe8c07
...
...
@@ -44,10 +44,6 @@ class EngineCoreRequest(
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py
request_id
:
str
prompt_token_ids
:
list
[
int
]
mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]]
...
...
@@ -59,6 +55,10 @@ class EngineCoreRequest(
lora_request
:
Optional
[
LoRARequest
]
cache_salt
:
Optional
[
str
]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index
:
int
=
0
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
...
...
vllm/v1/engine/async_llm.py
View file @
2dbe8c07
...
...
@@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.loggers
import
(
StatLoggerBase
,
StatLoggerFactory
,
setup_default_loggers
)
from
vllm.v1.metrics.prometheus
import
shutdown_prometheus
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
logger
=
init_logger
(
__name__
)
...
...
@@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
,
)
->
None
:
"""
Create an AsyncLLM.
...
...
@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
self
.
log_stats
,
client_addresses
=
client_addresses
,
client_index
=
client_index
,
)
if
self
.
stat_loggers
:
for
stat_logger
in
self
.
stat_loggers
[
0
]:
...
...
@@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
disable_log_requests
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
,
)
->
"AsyncLLM"
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
...
...
@@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
log_requests
=
not
disable_log_requests
,
log_stats
=
not
disable_log_stats
,
usage_context
=
usage_context
,
client_addresses
=
client_addresses
,
client_index
=
client_index
,
)
@
classmethod
...
...
@@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
def
shutdown
(
self
):
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus
()
if
engine_core
:
=
getattr
(
self
,
"engine_core"
,
None
):
engine_core
.
shutdown
()
...
...
@@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if
stat_loggers
:
assert
outputs
.
scheduler_stats
is
not
None
AsyncLLM
.
_record_stats
(
stat_loggers
[
outputs
.
engine_index
],
scheduler_stats
=
outputs
.
scheduler_stats
,
...
...
@@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
@
staticmethod
def
_record_stats
(
stat_loggers
:
list
[
StatLoggerBase
],
scheduler_stats
:
SchedulerStats
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
],
):
"""static so that it can be used from the output_handler task
...
...
vllm/v1/engine/coordinator.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
time
import
weakref
from
typing
import
Optional
import
msgspec.msgpack
import
zmq
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_mp_context
,
get_open_zmq_ipc_path
,
make_zmq_socket
from
vllm.v1.engine
import
EngineCoreOutputs
,
EngineCoreRequestType
from
vllm.v1.serial_utils
import
MsgpackDecoder
from
vllm.v1.utils
import
get_engine_client_zmq_addr
,
shutdown
logger
=
init_logger
(
__name__
)
class DPCoordinator:
    """Handle for the coordinator process of a data-parallel (DP>1) deployment.

    The coordinator process sits between the DP engine rank processes and
    the front-end API server process(es). Its responsibilities:

    * Gather stats from every DP engine (at present only the waiting and
      running queue lengths) and publish them to all front-ends, which use
      them for load-balancing decisions.

    * Track the current DP "request wave" number and whether the engines are
      running. This state comes from the DP rank 0 engine and is published
      to the front-ends together with the load stats.

      Engines flip between a global running state and a global paused state.
      The "request wave" number counts how many times the workers have
      collectively gone from running to paused; that transition is
      synchronized by the all-reduce performed in
      DPEngineCoreProc._has_global_unfinished_reqs.

    * Broadcast START_DP_WAVE to the engines so they leave the paused state
      once any one engine has received a new request. Two situations
      trigger this:

      1) A front-end that sends a new request while the engines are paused
         notifies the coordinator at the same time.
      2) An engine that, while paused, receives a request tagged with a
         stale wave number notifies the coordinator.

      An engine starts running when it receives either a new request or a
      START_DP_WAVE message.
    """

    def __init__(self, parallel_config: ParallelConfig):
        # Assume coordinator is colocated with front-end procs, so the
        # front-end facing socket gets a local IPC path.
        stats_address = get_open_zmq_ipc_path()

        engine_count = parallel_config.data_parallel_size
        assert engine_count > 1, "Coordinator only used for data parallel"

        # The engine-facing addresses depend on whether every DP rank runs
        # on this node (presumably IPC if so, TCP on the master IP
        # otherwise; confirm in get_engine_client_zmq_addr).
        all_local = engine_count == parallel_config.data_parallel_size_local
        master_ip = parallel_config.data_parallel_master_ip
        engine_in_address = get_engine_client_zmq_addr(all_local, master_ip)
        engine_out_address = get_engine_client_zmq_addr(all_local, master_ip)

        mp_context = get_mp_context()
        proc_kwargs = {
            "engine_count": engine_count,
            "front_publish_address": stats_address,
            "back_output_address": engine_out_address,
            "back_publish_address": engine_in_address,
        }
        self.proc: multiprocessing.Process = mp_context.Process(
            target=CoordinatorProc.run_coordinator,
            name="VLLM_DP_Coordinator",
            kwargs=proc_kwargs,
            daemon=True,
        )
        self.proc.start()

        self.stats_publish_address = stats_address
        self.coord_in_address = engine_in_address
        self.coord_out_address = engine_out_address

        # Run shutdown() on the child process when this handle is collected
        # or close() is called, whichever happens first.
        self._finalizer = weakref.finalize(self, shutdown, [self.proc])

    def get_stats_publish_address(self) -> str:
        """Return the address on which stats are published to front-ends."""
        return self.stats_publish_address

    def get_engine_socket_addresses(self) -> tuple[str, str]:
        """Return the engines' ZMQ addresses as (input, output)."""
        return (self.coord_in_address, self.coord_out_address)

    def close(self):
        """Invoke the finalizer now, shutting the coordinator process down."""
        self._finalizer()
class EngineState:
    """Load snapshot the coordinator keeps for a single engine."""

    def __init__(self):
        # Both counters start at zero; the list layout is [waiting, running].
        waiting, running = 0, 0
        self.request_counts = [waiting, running]
class CoordinatorProc:
    """State and main loop that run inside the DP coordinator process.

    Holds one EngineState per engine plus the global wave/running state,
    and shuttles messages between the front-end facing and engine facing
    ZMQ sockets.
    """

    def __init__(self, engine_count: int):
        """Create the ZMQ context and zeroed state for ``engine_count`` engines."""
        self.ctx = zmq.Context()

        self.engines = [EngineState() for _ in range(engine_count)]

        self.current_wave = 0
        self.engines_running = False
        # Set whenever state changes so the next publish happens sooner.
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
        front_publish_address: str,
        back_output_address: str,
        back_publish_address: str,
    ):
        """Process target: build a CoordinatorProc and run its loop.

        A KeyboardInterrupt is treated as a normal exit and only logged.
        """
        coordinator = CoordinatorProc(engine_count=engine_count)
        try:
            coordinator.process_input_socket(
                front_publish_address,
                back_output_address,
                back_publish_address,
            )
        except KeyboardInterrupt:
            logger.info("DP Coordinator process exiting")

    def process_input_socket(self, front_publish_address: str,
                             back_output_address: str,
                             back_publish_address: str):
        """Bind the three sockets and run the coordinator loop forever.

        Args:
            front_publish_address: Address of the XPUB socket shared with
                the front-ends (stats out, wake-up notifications in).
            back_output_address: Address of the PULL socket on which the
                engines send their EngineCoreOutputs.
            back_publish_address: Address of the XPUB socket used to
                broadcast START_DP_WAVE to the engines.
        """
        decoder = MsgpackDecoder(EngineCoreOutputs)

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_front, make_zmq_socket(
                path=back_output_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.PULL,
                bind=True,
        ) as output_back, make_zmq_socket(
                path=back_publish_address,  # IPC or TCP
                ctx=self.ctx,
                socket_type=zmq.XPUB,
                bind=True,
        ) as publish_back:

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
            # Milliseconds since the epoch; 0 forces an immediate first
            # publish once the poller is idle.
            last_publish_time = 0
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at 100 ms interval if the stats have changed,
                # or otherwise every 3 seconds.
                wait_for = 100 if self.stats_changed else 3000
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)

                if publish_front in events:
                    buffer = publish_front.recv()
                    if buffer in (b'\x01', b'\x00'):
                        # Ignore subscription messages. An XPUB socket
                        # surfaces both subscribe (0x01) and unsubscribe
                        # (0x00) frames; neither is a msgpack-encoded
                        # wake-up notification, so decoding them below
                        # would fail.
                        continue

                    # We received a message on the front-end XPUB socket,
                    # from an API server sending a new request while the
                    # engines are paused, so that we can wake the other
                    # engines.
                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
                    if wave < self.current_wave:
                        # If the wave number is stale, ensure the message is
                        # handled by all the engines.
                        engine_to_exclude = None
                    if not self.engines_running:
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, self.current_wave,
                                              engine_to_exclude)

                if output_back in events:
                    # We received a message from one of the engines.

                    buffer = output_back.recv()
                    outputs: EngineCoreOutputs = decoder.decode(buffer)

                    # Engines only send stats / wave notifications here,
                    # never request outputs or utility results.
                    assert not outputs.outputs
                    assert outputs.utility_output is None

                    eng_index = outputs.engine_index
                    if outputs.scheduler_stats:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats[0] = outputs.scheduler_stats.num_waiting_reqs
                        stats[1] = outputs.scheduler_stats.num_running_reqs
                        self.stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False)
                        if self.current_wave <= wave:
                            # Log the wave we actually move to (wave + 1),
                            # not the wave that just completed.
                            new_wave = wave + 1
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, new_wave)
                            self.current_wave = new_wave
                            self.engines_running = False
                            self.stats_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                        (wave == self.current_wave
                         and not self.engines_running)):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
        """Broadcast the START_DP_WAVE message to all the engines.
        It includes the current wave number and index of engine which
        has already received a request with this wave number and so doesn't
        require additional notification.
        """
        wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        return [e.request_counts for e in self.engines]
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment