Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2dbe8c07
Unverified
Commit
2dbe8c07
authored
May 30, 2025
by
Nick Hill
Committed by
GitHub
May 30, 2025
Browse files
[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
parent
84ec470f
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1019 additions
and
92 deletions
+1019
-92
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/entrypoints/test_api_server_process_manager.py
tests/entrypoints/test_api_server_process_manager.py
+268
-0
tests/utils.py
tests/utils.py
+3
-2
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+0
-1
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+0
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+5
-4
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+7
-7
tests/v1/entrypoints/openai/test_multi_api_servers.py
tests/v1/entrypoints/openai/test_multi_api_servers.py
+171
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+2
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+3
-3
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+0
-1
vllm/entrypoints/cli/serve.py
vllm/entrypoints/cli/serve.py
+170
-9
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+63
-39
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+5
-0
vllm/utils.py
vllm/utils.py
+5
-1
vllm/v1/core/sched/interface.py
vllm/v1/core/sched/interface.py
+8
-2
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+39
-14
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+4
-4
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+12
-2
vllm/v1/engine/coordinator.py
vllm/v1/engine/coordinator.py
+252
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
2dbe8c07
...
@@ -618,9 +618,11 @@ steps:
...
@@ -618,9 +618,11 @@ steps:
-
vllm/worker/model_runner.py
-
vllm/worker/model_runner.py
-
entrypoints/llm/test_collective_rpc.py
-
entrypoints/llm/test_collective_rpc.py
-
tests/v1/test_async_llm_dp.py
-
tests/v1/test_async_llm_dp.py
-
tests/v1/entrypoints/openai/test_multi_api_servers.py
-
vllm/v1/engine/
-
vllm/v1/engine/
commands
:
commands
:
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
-
DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s entrypoints/llm/test_collective_rpc.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
...
...
tests/entrypoints/test_api_server_process_manager.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
socket
import
threading
import
time
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
from
vllm.v1.utils
import
(
APIServerProcessManager
,
wait_for_completion_or_failure
)
# Global variables to control worker behavior
WORKER_RUNTIME_SECONDS
=
0.5
# Mock implementation of run_api_server_worker
def
mock_run_api_server_worker
(
listen_address
,
sock
,
args
,
client_config
=
None
):
"""Mock run_api_server_worker that runs for a specific time."""
print
(
f
"Mock worker started with client_config:
{
client_config
}
"
)
time
.
sleep
(
WORKER_RUNTIME_SECONDS
)
print
(
"Mock worker completed successfully"
)
@
pytest
.
fixture
def
api_server_args
():
"""Fixture to provide arguments for APIServerProcessManager."""
sock
=
socket
.
socket
()
return
{
"target_server_fn"
:
mock_run_api_server_worker
,
"listen_address"
:
"localhost:8000"
,
"sock"
:
sock
,
"args"
:
"test_args"
,
# Simple string to avoid pickling issues
"num_servers"
:
3
,
"input_addresses"
:
[
"tcp://127.0.0.1:5001"
,
"tcp://127.0.0.1:5002"
,
"tcp://127.0.0.1:5003"
],
"output_addresses"
:
[
"tcp://127.0.0.1:6001"
,
"tcp://127.0.0.1:6002"
,
"tcp://127.0.0.1:6003"
],
"stats_update_address"
:
"tcp://127.0.0.1:7000"
,
}
@
pytest
.
mark
.
parametrize
(
"with_stats_update"
,
[
True
,
False
])
def
test_api_server_process_manager_init
(
api_server_args
,
with_stats_update
):
"""Test initializing the APIServerProcessManager."""
# Set the worker runtime to ensure tests complete in reasonable time
global
WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS
=
0.5
# Copy the args to avoid mutating the
args
=
api_server_args
.
copy
()
if
not
with_stats_update
:
args
.
pop
(
"stats_update_address"
)
manager
=
APIServerProcessManager
(
**
args
)
try
:
# Verify the manager was initialized correctly
assert
len
(
manager
.
processes
)
==
3
# Verify all processes are running
for
proc
in
manager
.
processes
:
assert
proc
.
is_alive
()
print
(
"Waiting for processes to run..."
)
time
.
sleep
(
WORKER_RUNTIME_SECONDS
/
2
)
# They should still be alive at this point
for
proc
in
manager
.
processes
:
assert
proc
.
is_alive
()
finally
:
# Always clean up the processes
print
(
"Cleaning up processes..."
)
manager
.
close
()
# Give processes time to terminate
time
.
sleep
(
0.2
)
# Verify all processes were terminated
for
proc
in
manager
.
processes
:
assert
not
proc
.
is_alive
()
@
patch
(
"vllm.entrypoints.cli.serve.run_api_server_worker"
,
mock_run_api_server_worker
)
def
test_wait_for_completion_or_failure
(
api_server_args
):
"""Test that wait_for_completion_or_failure works with failures."""
global
WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS
=
1.0
# Create the manager
manager
=
APIServerProcessManager
(
**
api_server_args
)
try
:
assert
len
(
manager
.
processes
)
==
3
# Create a result capture for the thread
result
:
dict
[
str
,
Optional
[
Exception
]]
=
{
"exception"
:
None
}
def
run_with_exception_capture
():
try
:
wait_for_completion_or_failure
(
api_server_manager
=
manager
)
except
Exception
as
e
:
result
[
"exception"
]
=
e
# Start a thread to run wait_for_completion_or_failure
wait_thread
=
threading
.
Thread
(
target
=
run_with_exception_capture
,
daemon
=
True
)
wait_thread
.
start
()
# Let all processes run for a short time
time
.
sleep
(
0.2
)
# All processes should still be running
assert
all
(
proc
.
is_alive
()
for
proc
in
manager
.
processes
)
# Now simulate a process failure
print
(
"Simulating process failure..."
)
manager
.
processes
[
0
].
terminate
()
# Wait for the wait_for_completion_or_failure
# to detect and handle the failure
# This should trigger it to terminate all other processes
wait_thread
.
join
(
timeout
=
1.0
)
# The wait thread should have exited
assert
not
wait_thread
.
is_alive
()
# Verify that an exception was raised with appropriate error message
assert
result
[
"exception"
]
is
not
None
assert
"died with exit code"
in
str
(
result
[
"exception"
])
# All processes should now be terminated
for
i
,
proc
in
enumerate
(
manager
.
processes
):
assert
not
proc
.
is_alive
(),
f
"Process
{
i
}
should not be alive"
finally
:
manager
.
close
()
time
.
sleep
(
0.2
)
@
pytest
.
mark
.
timeout
(
30
)
def
test_normal_completion
(
api_server_args
):
"""Test that wait_for_completion_or_failure works in normal completion."""
global
WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS
=
0.1
# Create the manager
manager
=
APIServerProcessManager
(
**
api_server_args
)
try
:
# Give processes time to terminate
# wait for processes to complete
remaining_processes
=
manager
.
processes
.
copy
()
while
remaining_processes
:
for
proc
in
remaining_processes
:
if
not
proc
.
is_alive
():
remaining_processes
.
remove
(
proc
)
time
.
sleep
(
0.1
)
# Verify all processes have terminated
for
i
,
proc
in
enumerate
(
manager
.
processes
):
assert
not
proc
.
is_alive
(
),
f
"Process
{
i
}
still alive after terminate()"
# Now call wait_for_completion_or_failure
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure
(
api_server_manager
=
manager
)
finally
:
# Clean up just in case
manager
.
close
()
time
.
sleep
(
0.2
)
@
pytest
.
mark
.
timeout
(
30
)
def
test_external_process_monitoring
(
api_server_args
):
"""Test that wait_for_completion_or_failure handles additional processes."""
global
WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS
=
100
# Create and start the external process
# (simulates local_engine_manager or coordinator)
spawn_context
=
multiprocessing
.
get_context
(
"spawn"
)
external_proc
=
spawn_context
.
Process
(
target
=
mock_run_api_server_worker
,
name
=
"MockExternalProcess"
)
external_proc
.
start
()
# Create the class to simulate a coordinator
class
MockCoordinator
:
def
__init__
(
self
,
proc
):
self
.
proc
=
proc
def
close
(
self
):
if
self
.
proc
.
is_alive
():
self
.
proc
.
terminate
()
self
.
proc
.
join
(
timeout
=
0.5
)
# Create a mock coordinator with the external process
mock_coordinator
=
MockCoordinator
(
external_proc
)
# Create the API server manager
manager
=
APIServerProcessManager
(
**
api_server_args
)
try
:
# Verify manager initialization
assert
len
(
manager
.
processes
)
==
3
# Create a result capture for the thread
result
:
dict
[
str
,
Optional
[
Exception
]]
=
{
"exception"
:
None
}
def
run_with_exception_capture
():
try
:
wait_for_completion_or_failure
(
api_server_manager
=
manager
,
coordinator
=
mock_coordinator
)
except
Exception
as
e
:
result
[
"exception"
]
=
e
# Start a thread to run wait_for_completion_or_failure
wait_thread
=
threading
.
Thread
(
target
=
run_with_exception_capture
,
daemon
=
True
)
wait_thread
.
start
()
# Terminate the external process to trigger a failure
time
.
sleep
(
0.2
)
external_proc
.
terminate
()
# Wait for the thread to detect the failure
wait_thread
.
join
(
timeout
=
1.0
)
# The wait thread should have completed
assert
not
wait_thread
.
is_alive
(
),
"wait_for_completion_or_failure thread still running"
# Verify that an exception was raised with appropriate error message
assert
result
[
"exception"
]
is
not
None
,
"No exception was raised"
error_message
=
str
(
result
[
"exception"
])
assert
"died with exit code"
in
error_message
,
\
f
"Unexpected error message:
{
error_message
}
"
assert
"MockExternalProcess"
in
error_message
,
\
f
"Error doesn't mention external process:
{
error_message
}
"
# Verify that all API server processes were terminated as a result
for
i
,
proc
in
enumerate
(
manager
.
processes
):
assert
not
proc
.
is_alive
(
),
f
"API server process
{
i
}
was not terminated"
finally
:
# Clean up
manager
.
close
()
mock_coordinator
.
close
()
time
.
sleep
(
0.2
)
tests/utils.py
View file @
2dbe8c07
...
@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
...
@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.
openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.
cli.serve
import
ServeSubcommand
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
...
@@ -99,7 +99,8 @@ class RemoteOpenAIServer:
...
@@ -99,7 +99,8 @@ class RemoteOpenAIServer:
parser
=
FlexibleArgumentParser
(
parser
=
FlexibleArgumentParser
(
description
=
"vLLM's remote OpenAI server."
)
description
=
"vLLM's remote OpenAI server."
)
parser
=
make_arg_parser
(
parser
)
subparsers
=
parser
.
add_subparsers
(
required
=
False
,
dest
=
"subparser"
)
parser
=
ServeSubcommand
().
subparser_init
(
subparsers
)
args
=
parser
.
parse_args
([
"--model"
,
model
,
*
vllm_serve_args
])
args
=
parser
.
parse_args
([
"--model"
,
model
,
*
vllm_serve_args
])
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
host
=
str
(
args
.
host
or
'localhost'
)
self
.
port
=
int
(
args
.
port
)
self
.
port
=
int
(
args
.
port
)
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
2dbe8c07
...
@@ -45,7 +45,6 @@ def make_request(request_id,
...
@@ -45,7 +45,6 @@ def make_request(request_id,
multi_modal_placeholders
=
mm_positions
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
eos_token_id
=
100
,
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
cache_salt
=
cache_salt
,
)
)
...
...
tests/v1/core/test_prefix_caching.py
View file @
2dbe8c07
...
@@ -38,7 +38,6 @@ def make_request(request_id,
...
@@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
prompt_logprobs
=
prompt_logprobs
),
eos_token_id
=
100
,
eos_token_id
=
100
,
arrival_time
=
0
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
cache_salt
=
cache_salt
,
)
)
...
...
tests/v1/core/test_scheduler.py
View file @
2dbe8c07
...
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
...
@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders
=
mm_position
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
)
)
requests
.
append
(
request
)
requests
.
append
(
request
)
return
requests
return
requests
...
@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
...
@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert
running_req
.
num_tokens_with_spec
==
2
+
len
(
spec_tokens
[
i
])
assert
running_req
.
num_tokens_with_spec
==
2
+
len
(
spec_tokens
[
i
])
# No draft or accepted tokens counted yet
# No draft or accepted tokens counted yet
assert
engine_core_outputs
.
scheduler_stats
.
spec_decoding_stats
is
None
assert
not
engine_core_outputs
or
(
engine_core_outputs
[
0
].
scheduler_stats
.
spec_decoding_stats
is
None
)
# Schedule the speculated tokens for validation
# Schedule the speculated tokens for validation
output
=
scheduler
.
schedule
()
output
=
scheduler
.
schedule
()
...
@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
...
@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs
=
scheduler
.
update_from_output
(
output
,
engine_core_outputs
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
model_runner_output
)
scheduler_stats
=
engine_core_outputs
.
scheduler_stats
scheduler_stats
=
engine_core_outputs
[
0
].
scheduler_stats
\
if
engine_core_outputs
else
None
if
expected
[
0
]
==
0
:
if
expected
[
0
]
==
0
:
assert
scheduler_stats
.
spec_decoding_stats
is
None
assert
scheduler_stats
.
spec_decoding_stats
is
None
else
:
else
:
...
@@ -843,7 +844,7 @@ def _step_until_done(
...
@@ -843,7 +844,7 @@ def _step_until_done(
# We should be in the decode phase now.
# We should be in the decode phase now.
assert
num_scheduled_tokens
==
1
assert
num_scheduled_tokens
==
1
assert
len
(
output
.
kv_connector_metadata
.
requests
)
==
0
assert
len
(
output
.
kv_connector_metadata
.
requests
)
==
0
ecos
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
ecos
=
scheduler
.
update_from_output
(
output
,
model_runner_output
)
[
0
]
all_done
=
True
all_done
=
True
for
eco
in
ecos
.
outputs
:
for
eco
in
ecos
.
outputs
:
if
eco
.
finish_reason
is
None
:
if
eco
.
finish_reason
is
None
:
...
...
tests/v1/engine/test_engine_core.py
View file @
2dbe8c07
...
@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
...
@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert
len
(
engine_core
.
scheduler
.
running
)
==
4
assert
len
(
engine_core
.
scheduler
.
running
)
==
4
# Loop through until they are all done.
# Loop through until they are all done.
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
...
@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
...
@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0
.
request_id
=
req1
.
request_id
=
"test"
req0
.
request_id
=
req1
.
request_id
=
"test"
engine_core
.
add_request
(
req0
)
engine_core
.
add_request
(
req0
)
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
pass
engine_core
.
add_request
(
req1
)
engine_core
.
add_request
(
req1
)
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
...
@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
...
@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
# Loop through until they are all done.
# Loop through until they are all done.
while
len
(
engine_core
.
step
()[
0
].
outputs
)
>
0
:
while
(
outs
:
=
engine_core
.
step
()[
0
].
get
(
0
))
and
outs
.
outputs
:
pass
pass
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
...
@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
4
assert
scheduler_output
.
num_scheduled_tokens
[
1
]
==
4
# Batch queue is full. Finish Batch 2. Get first token of req0.
# Batch queue is full. Finish Batch 2. Get first token of req0.
output
=
engine_core
.
step_with_batch_queue
()[
0
]
output
=
engine_core
.
step_with_batch_queue
()[
0
]
.
get
(
0
)
assert
output
is
not
None
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req0
.
request_id
].
num_tokens
==
13
assert
engine_core
.
scheduler
.
requests
[
req0
.
request_id
].
num_tokens
==
13
...
@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
1
assert
scheduler_output
.
num_scheduled_tokens
[
0
]
==
1
# Batch queue is full. Finish Batch 3. Get first token of req1.
# Batch queue is full. Finish Batch 3. Get first token of req1.
output
=
engine_core
.
step_with_batch_queue
()[
0
]
output
=
engine_core
.
step_with_batch_queue
()[
0
]
.
get
(
0
)
assert
output
is
not
None
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
len
(
output
.
outputs
)
==
1
assert
engine_core
.
scheduler
.
requests
[
req1
.
request_id
].
num_tokens
==
13
assert
engine_core
.
scheduler
.
requests
[
req1
.
request_id
].
num_tokens
==
13
...
@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
if
step
%
2
==
0
:
if
step
%
2
==
0
:
# Even steps consumes an output.
# Even steps consumes an output.
assert
output
is
not
None
assert
output
is
not
None
assert
len
(
output
.
outputs
)
==
1
assert
len
(
output
[
0
]
.
outputs
)
==
1
if
req_id
in
engine_core
.
scheduler
.
requests
:
if
req_id
in
engine_core
.
scheduler
.
requests
:
assert
engine_core
.
scheduler
.
requests
[
assert
engine_core
.
scheduler
.
requests
[
req_id
].
num_tokens
==
expected_num_tokens
[
req_id
]
req_id
].
num_tokens
==
expected_num_tokens
[
req_id
]
...
...
tests/v1/entrypoints/openai/test_multi_api_servers.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
os
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
MODEL_NAME
=
"ibm-research/PowerMoE-3b"
DP_SIZE
=
os
.
getenv
(
"DP_SIZE"
,
"1"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
default_server_args
():
return
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"2048"
,
"--max-num-seqs"
,
"128"
,
"--enforce-eager"
,
"--api-server-count"
,
"4"
,
"--data_parallel_size"
,
DP_SIZE
,
]
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
(
default_server_args
):
with
RemoteOpenAIServer
(
MODEL_NAME
,
default_server_args
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_single_completion
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
)
->
None
:
async
def
make_request
():
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello, my name is"
,
max_tokens
=
10
,
temperature
=
1.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
choice
=
completion
.
choices
[
0
]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert
len
(
choice
.
text
)
>=
1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert
choice
.
finish_reason
in
(
"length"
,
"stop"
)
# Token counts can also vary, so we check they are positive.
assert
completion
.
usage
.
completion_tokens
>
0
assert
completion
.
usage
.
prompt_tokens
>
0
assert
completion
.
usage
.
total_tokens
>
0
return
completion
# Test single request
result
=
await
make_request
()
assert
result
is
not
None
await
asyncio
.
sleep
(
0.5
)
# Send two bursts of requests
num_requests
=
100
tasks
=
[
make_request
()
for
_
in
range
(
num_requests
)]
results
=
await
asyncio
.
gather
(
*
tasks
)
assert
len
(
results
)
==
num_requests
assert
all
(
completion
is
not
None
for
completion
in
results
)
await
asyncio
.
sleep
(
0.5
)
tasks
=
[
make_request
()
for
_
in
range
(
num_requests
)]
results
=
await
asyncio
.
gather
(
*
tasks
)
assert
len
(
results
)
==
num_requests
assert
all
(
completion
is
not
None
for
completion
in
results
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_completion_streaming
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
)
->
None
:
prompt
=
"What is an LLM?"
async
def
make_streaming_request
():
# Perform a non-streaming request to get the expected full output
single_completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
5
,
temperature
=
0.0
,
)
single_output
=
single_completion
.
choices
[
0
].
text
# Perform the streaming request
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
5
,
temperature
=
0.0
,
stream
=
True
)
chunks
:
list
[
str
]
=
[]
finish_reason_count
=
0
last_chunk
=
None
async
for
chunk
in
stream
:
chunks
.
append
(
chunk
.
choices
[
0
].
text
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason_count
+=
1
last_chunk
=
chunk
# Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert
finish_reason_count
==
1
,
(
"Finish reason should appear exactly once."
)
assert
last_chunk
is
not
None
,
(
"Stream should have yielded at least one chunk."
)
assert
last_chunk
.
choices
[
0
].
finish_reason
==
"length"
,
"Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert
""
.
join
(
chunks
)
==
single_output
,
"Streamed output should match non-streamed output."
return
True
# Indicate success for this request
# Test single request
result
=
await
make_streaming_request
()
assert
result
is
not
None
await
asyncio
.
sleep
(
0.5
)
# Send two bursts of requests
num_requests
=
100
tasks
=
[
make_streaming_request
()
for
_
in
range
(
num_requests
)]
results
=
await
asyncio
.
gather
(
*
tasks
)
assert
len
(
results
)
==
num_requests
,
f
"Expected
{
num_requests
}
results, got
{
len
(
results
)
}
"
assert
all
(
results
),
"Not all streaming requests completed successfully."
await
asyncio
.
sleep
(
0.5
)
tasks
=
[
make_streaming_request
()
for
_
in
range
(
num_requests
)]
results
=
await
asyncio
.
gather
(
*
tasks
)
assert
len
(
results
)
==
num_requests
,
f
"Expected
{
num_requests
}
results, got
{
len
(
results
)
}
"
assert
all
(
results
),
"Not all streaming requests completed successfully."
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
2dbe8c07
...
@@ -43,7 +43,7 @@ def test_basic_lifecycle():
...
@@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 tokens.
# Ensure the request is finished after 1 tokens.
assert
request
.
is_finished
()
assert
request
.
is_finished
()
assert
request
.
status
==
RequestStatus
.
FINISHED_LENGTH_CAPPED
assert
request
.
status
==
RequestStatus
.
FINISHED_LENGTH_CAPPED
output
=
engine_core_outputs
.
outputs
[
0
]
output
=
engine_core_outputs
[
0
]
.
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
LENGTH
assert
output
.
finish_reason
==
FinishReason
.
LENGTH
assert
output
.
kv_transfer_params
is
not
None
assert
output
.
kv_transfer_params
is
not
None
...
@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
...
@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_remote
])
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
eco
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
kv_transfer_params
=
eco
.
outputs
[
0
].
kv_transfer_params
kv_transfer_params
=
eco
[
0
]
.
outputs
[
0
].
kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
# Ensure we send all block ids, even if there is a cache hit.
assert
(
len
(
assert
(
len
(
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
2dbe8c07
...
@@ -61,7 +61,7 @@ def test_basic_lifecycle():
...
@@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output()
# (1c): update_from_output()
engine_core_outputs
=
scheduler
.
update_from_output
(
scheduler_output
,
engine_core_outputs
=
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
model_runner_output
)
assert
len
(
engine_core_outputs
.
outputs
)
==
0
assert
not
engine_core_outputs
or
not
engine_core_outputs
[
0
]
.
outputs
# STEP (2):
# STEP (2):
# (2a): schedule(): nothing happens!
# (2a): schedule(): nothing happens!
...
@@ -112,7 +112,7 @@ def test_basic_lifecycle():
...
@@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output
)
model_runner_output
)
scheduler
.
schedule
()
scheduler
.
schedule
()
outputs
=
engine_core_outputs
.
outputs
outputs
=
engine_core_outputs
[
0
]
.
outputs
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
output
=
outputs
[
0
]
output
=
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
STOP
assert
output
.
finish_reason
==
FinishReason
.
STOP
...
@@ -335,7 +335,7 @@ def test_full_block_prompt():
...
@@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output
)
model_runner_output
)
scheduler
.
schedule
()
scheduler
.
schedule
()
outputs
=
engine_core_outputs
.
outputs
outputs
=
engine_core_outputs
[
0
]
.
outputs
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
output
=
outputs
[
0
]
output
=
outputs
[
0
]
assert
output
.
finish_reason
==
FinishReason
.
STOP
assert
output
.
finish_reason
==
FinishReason
.
STOP
...
...
tests/v1/kv_connector/unit/utils.py
View file @
2dbe8c07
...
@@ -153,7 +153,6 @@ def create_request(
...
@@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
arrival_time
=
0
,
)
)
req
.
kv_transfer_params
=
kv_transfer_params
req
.
kv_transfer_params
=
kv_transfer_params
return
req
return
req
...
...
vllm/entrypoints/cli/serve.py
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
argparse
import
os
import
signal
import
signal
import
sys
import
uvloop
import
uvloop
import
zmq
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
AsyncEngineArgs
from
vllm
import
AsyncEngineArgs
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.cli.types
import
CLISubcommand
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
(
run_server
,
run_server_worker
,
setup_server
)
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
validate_parsed_serve_args
)
from
vllm.entrypoints.utils
import
(
VLLM_SERVE_PARSER_EPILOG
,
from
vllm.entrypoints.utils
import
(
VLLM_SERVE_PARSER_EPILOG
,
show_filtered_argument_or_group_from_help
)
show_filtered_argument_or_group_from_help
)
from
vllm.executor.multiproc_worker_utils
import
_add_prefix
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
from
vllm.utils
import
FlexibleArgumentParser
,
get_tcp_uri
,
zmq_socket_ctx
from
vllm.v1.engine.coordinator
import
DPCoordinator
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core
import
EngineCoreProc
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.engine.core_client
import
CoreEngineProcManager
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.prometheus
import
setup_multiprocess_prometheus
from
vllm.v1.utils
import
(
APIServerProcessManager
,
CoreEngine
,
EngineZmqAddresses
,
get_engine_client_zmq_addr
,
wait_for_completion_or_failure
,
wait_for_engine_startup
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
...
@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
if
hasattr
(
args
,
'model_tag'
)
and
args
.
model_tag
is
not
None
:
args
.
model
=
args
.
model_tag
args
.
model
=
args
.
model_tag
if
args
.
headless
:
if
args
.
headless
or
args
.
api_server_count
<
1
:
run_headless
(
args
)
run_headless
(
args
)
elif
args
.
api_server_count
>
1
:
run_multi_api_server
(
args
)
else
:
else
:
# Single API server (this process).
uvloop
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
def
validate
(
self
,
args
:
argparse
.
Namespace
)
->
None
:
...
@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
...
@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
type
=
int
,
type
=
int
,
default
=
0
,
default
=
0
,
help
=
'Starting data parallel rank for secondary nodes.'
)
help
=
'Starting data parallel rank for secondary nodes.'
)
serve_parser
.
add_argument
(
'--api-server-count'
,
'-asc'
,
type
=
int
,
default
=
1
,
help
=
'How many API server processes to run.'
)
serve_parser
.
add_argument
(
serve_parser
.
add_argument
(
"--config"
,
"--config"
,
type
=
str
,
type
=
str
,
...
@@ -91,22 +110,25 @@ def cmd_init() -> list[CLISubcommand]:
...
@@ -91,22 +110,25 @@ def cmd_init() -> list[CLISubcommand]:
def
run_headless
(
args
:
argparse
.
Namespace
):
def
run_headless
(
args
:
argparse
.
Namespace
):
if
args
.
api_server_count
>
1
:
raise
ValueError
(
"api_server_count can't be set in headless mode"
)
# Create the EngineConfig.
# Create the EngineConfig.
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
if
not
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
Runtim
eError
(
"Headless mode is only supported for V1"
)
raise
Valu
eError
(
"Headless mode is only supported for V1"
)
parallel_config
=
vllm_config
.
parallel_config
parallel_config
=
vllm_config
.
parallel_config
local_engine_count
=
parallel_config
.
data_parallel_size_local
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
host
=
parallel_config
.
data_parallel_master_ip
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
port
=
engine_args
.
data_parallel_rpc_port
# add to config too
input
_address
=
get_tcp_uri
(
host
,
port
)
handshake
_address
=
get_tcp_uri
(
host
,
port
)
if
local_engine_count
<=
0
:
if
local_engine_count
<=
0
:
raise
Runtim
eError
(
"data_parallel_size_local must be > 0 in "
raise
Valu
eError
(
"data_parallel_size_local must be > 0 in "
"headless mode"
)
"headless mode"
)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
# Catch SIGTERM and SIGINT to allow graceful shutdown.
...
@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):
...
@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):
logger
.
info
(
logger
.
info
(
"Launching %d data parallel engine(s) in headless mode, "
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s."
,
local_engine_count
,
input
_address
)
"with head node address %s."
,
local_engine_count
,
handshake
_address
)
# Create the engines.
# Create the engines.
engine_manager
=
CoreEngineProcManager
(
engine_manager
=
CoreEngineProcManager
(
...
@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
...
@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
local_start_index
=
0
,
local_start_index
=
0
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
on_head_node
=
False
,
on_head_node
=
False
,
input_address
=
input
_address
,
handshake_address
=
handshake
_address
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
log_stats
=
not
engine_args
.
disable_log_stats
,
)
)
...
@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
...
@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
finally
:
finally
:
logger
.
info
(
"Shutting down."
)
logger
.
info
(
"Shutting down."
)
engine_manager
.
close
()
engine_manager
.
close
()
def
run_multi_api_server
(
args
:
argparse
.
Namespace
):
assert
not
args
.
headless
num_api_servers
=
args
.
api_server_count
assert
num_api_servers
>
0
if
num_api_servers
>
1
:
setup_multiprocess_prometheus
()
listen_address
,
sock
=
setup_server
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
usage_context
=
UsageContext
.
OPENAI_API_SERVER
vllm_config
=
engine_args
.
create_engine_config
(
usage_context
=
usage_context
)
model_config
=
vllm_config
.
model_config
if
num_api_servers
>
1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"api_server_count > 1 is only supported for V1"
)
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
raise
ValueError
(
"VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1"
)
if
model_config
.
is_multimodal_model
and
not
(
model_config
.
disable_mm_preprocessor_cache
):
logger
.
warning
(
"Multi-model preprocessor cache will be disabled for"
" api_server_count > 1"
)
model_config
.
disable_mm_preprocessor_cache
=
True
parallel_config
=
vllm_config
.
parallel_config
assert
parallel_config
.
data_parallel_rank
==
0
dp_size
=
parallel_config
.
data_parallel_size
local_engine_count
=
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
local_only
=
local_engine_count
==
dp_size
# Set up input and output addresses.
input_addresses
=
[
get_engine_client_zmq_addr
(
local_only
,
host
)
for
_
in
range
(
num_api_servers
)
]
output_addresses
=
[
get_engine_client_zmq_addr
(
local_only
,
host
)
for
_
in
range
(
num_api_servers
)
]
addresses
=
EngineZmqAddresses
(
inputs
=
input_addresses
,
outputs
=
output_addresses
,
)
# Set up coordinator for dp > 1.
coordinator
=
None
stats_update_address
=
None
if
dp_size
>
1
:
coordinator
=
DPCoordinator
(
parallel_config
)
addresses
.
coordinator_input
,
addresses
.
coordinator_output
=
(
coordinator
.
get_engine_socket_addresses
())
stats_update_address
=
coordinator
.
get_stats_publish_address
()
logger
.
info
(
"Started DP Coordinator process (PID: %d)"
,
coordinator
.
proc
.
pid
)
handshake_address
=
get_engine_client_zmq_addr
(
local_only
,
host
,
parallel_config
.
data_parallel_rpc_port
)
with
zmq_socket_ctx
(
handshake_address
,
zmq
.
ROUTER
,
bind
=
True
)
as
handshake_socket
:
# Start local engines.
if
not
local_engine_count
:
local_engine_manager
=
None
else
:
local_engine_manager
=
CoreEngineProcManager
(
EngineCoreProc
.
run_engine_core
,
vllm_config
=
vllm_config
,
executor_class
=
Executor
.
get_class
(
vllm_config
),
log_stats
=
not
engine_args
.
disable_log_stats
,
handshake_address
=
handshake_address
,
on_head_node
=
True
,
local_engine_count
=
local_engine_count
,
start_index
=
0
,
local_start_index
=
0
)
# Start API servers using the manager
api_server_manager
=
APIServerProcessManager
(
target_server_fn
=
run_api_server_worker_proc
,
listen_address
=
listen_address
,
sock
=
sock
,
args
=
args
,
num_servers
=
num_api_servers
,
input_addresses
=
input_addresses
,
output_addresses
=
output_addresses
,
stats_update_address
=
stats_update_address
)
# Wait for engine handshakes to complete.
core_engines
=
[
CoreEngine
(
index
=
i
,
local
=
(
i
<
local_engine_count
))
for
i
in
range
(
dp_size
)
]
wait_for_engine_startup
(
handshake_socket
,
addresses
,
core_engines
,
parallel_config
,
vllm_config
.
cache_config
,
local_engine_manager
,
coordinator
.
proc
if
coordinator
else
None
,
)
# Wait for API servers
wait_for_completion_or_failure
(
api_server_manager
=
api_server_manager
,
local_engine_manager
=
local_engine_manager
,
coordinator
=
coordinator
)
def
run_api_server_worker_proc
(
listen_address
,
sock
,
args
,
client_config
=
None
,
**
uvicorn_kwargs
)
->
None
:
"""Entrypoint for individual API server worker processes."""
# Add process-specific prefix to stdout and stderr.
from
multiprocessing
import
current_process
process_name
=
current_process
().
name
pid
=
os
.
getpid
()
_add_prefix
(
sys
.
stdout
,
process_name
,
pid
)
_add_prefix
(
sys
.
stderr
,
process_name
,
pid
)
uvloop
.
run
(
run_server_worker
(
listen_address
,
sock
,
args
,
client_config
,
**
uvicorn_kwargs
))
vllm/entrypoints/openai/api_server.py
View file @
2dbe8c07
...
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
...
@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from
functools
import
partial
from
functools
import
partial
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
json
import
JSONDecodeError
from
json
import
JSONDecodeError
from
typing
import
Annotated
,
Optional
from
typing
import
Annotated
,
Any
,
Optional
import
prometheus_client
import
prometheus_client
import
regex
as
re
import
regex
as
re
...
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
...
@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
from
prometheus_fastapi_instrumentator
import
Instrumentator
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.datastructures
import
State
from
starlette.datastructures
import
State
from
starlette.routing
import
Mount
from
starlette.routing
import
Mount
...
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
...
@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Device
,
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
from
vllm.utils
import
(
Device
,
FlexibleArgumentParser
,
get_open_zmq_ipc_path
,
is_valid_ipv6_address
,
set_ulimit
)
is_valid_ipv6_address
,
set_ulimit
)
from
vllm.v1.metrics.prometheus
import
get_prometheus_registry
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
...
@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
@
asynccontextmanager
@
asynccontextmanager
async
def
build_async_engine_client
(
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
EngineClient
]:
args
:
Namespace
,
client_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncIterator
[
EngineClient
]:
# Context manager to handle engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
# Ensures everything is shutdown and cleaned up on error/exit
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
async
with
build_async_engine_client_from_engine_args
(
async
with
build_async_engine_client_from_engine_args
(
engine_args
,
args
.
disable_frontend_multiprocessing
)
as
engine
:
engine_args
,
args
.
disable_frontend_multiprocessing
,
client_config
)
as
engine
:
yield
engine
yield
engine
...
@@ -157,6 +163,7 @@ async def build_async_engine_client(
...
@@ -157,6 +163,7 @@ async def build_async_engine_client(
async
def
build_async_engine_client_from_engine_args
(
async
def
build_async_engine_client_from_engine_args
(
engine_args
:
AsyncEngineArgs
,
engine_args
:
AsyncEngineArgs
,
disable_frontend_multiprocessing
:
bool
=
False
,
disable_frontend_multiprocessing
:
bool
=
False
,
client_config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
AsyncIterator
[
EngineClient
]:
)
->
AsyncIterator
[
EngineClient
]:
"""
"""
Create EngineClient, either:
Create EngineClient, either:
...
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
...
@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
async_llm
:
Optional
[
AsyncLLM
]
=
None
async_llm
:
Optional
[
AsyncLLM
]
=
None
client_index
=
client_config
.
pop
(
"client_index"
)
if
client_config
else
0
try
:
try
:
async_llm
=
AsyncLLM
.
from_vllm_config
(
async_llm
=
AsyncLLM
.
from_vllm_config
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
usage_context
=
usage_context
,
usage_context
=
usage_context
,
disable_log_requests
=
engine_args
.
disable_log_requests
,
disable_log_requests
=
engine_args
.
disable_log_requests
,
disable_log_stats
=
engine_args
.
disable_log_stats
)
disable_log_stats
=
engine_args
.
disable_log_stats
,
client_addresses
=
client_config
,
client_index
=
client_index
)
# Don't keep the dummy data in memory
# Don't keep the dummy data in memory
await
async_llm
.
reset_mm_cache
()
await
async_llm
.
reset_mm_cache
()
...
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
...
@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
def
mount_metrics
(
app
:
FastAPI
):
def
mount_metrics
(
app
:
FastAPI
):
# Lazy import for prometheus multiprocessing.
"""Mount prometheus metrics to a FastAPI app."""
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
(
REGISTRY
,
CollectorRegistry
,
make_asgi_app
,
multiprocess
)
from
prometheus_fastapi_instrumentator
import
Instrumentator
registry
=
REGISTRY
prometheus_multiproc_dir_path
=
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
,
None
)
registry
=
get_prometheus_registry
()
if
prometheus_multiproc_dir_path
is
not
None
:
logger
.
debug
(
"vLLM to use %s as PROMETHEUS_MULTIPROC_DIR"
,
prometheus_multiproc_dir_path
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
# `response_class=PrometheusResponse` is needed to return an HTTP response
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
...
@@ -1256,13 +1254,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
...
@@ -1256,13 +1254,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return
sock
return
sock
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
def
validate_api_server_args
(
args
):
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
log_non_default_args
(
args
)
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
valid_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
valid_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
if
args
.
enable_auto_tool_choice
\
if
args
.
enable_auto_tool_choice
\
and
args
.
tool_call_parser
not
in
valid_tool_parses
:
and
args
.
tool_call_parser
not
in
valid_tool_parses
:
...
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f
"invalid reasoning parser:
{
args
.
reasoning_parser
}
"
f
"invalid reasoning parser:
{
args
.
reasoning_parser
}
"
f
"(chose from {{
{
','
.
join
(
valid_reasoning_parses
)
}
}})"
)
f
"(chose from {{
{
','
.
join
(
valid_reasoning_parses
)
}
}})"
)
def
setup_server
(
args
):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
log_non_default_args
(
args
)
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
validate_api_server_args
(
args
)
# workaround to make sure that we bind the port before the engine is set up.
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
# see https://github.com/vllm-project/vllm/issues/8204
...
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
...
@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
signal
.
signal
(
signal
.
SIGTERM
,
signal_handler
)
async
with
build_async_engine_client
(
args
)
as
engine_client
:
addr
,
port
=
sock_addr
is_ssl
=
args
.
ssl_keyfile
and
args
.
ssl_certfile
host_part
=
f
"[
{
addr
}
]"
if
is_valid_ipv6_address
(
addr
)
else
addr
or
"0.0.0.0"
listen_address
=
f
"http
{
's'
if
is_ssl
else
''
}
://
{
host_part
}
:
{
port
}
"
return
listen_address
,
sock
async
def
run_server
(
args
,
**
uvicorn_kwargs
)
->
None
:
"""Run a single-worker API server."""
listen_address
,
sock
=
setup_server
(
args
)
await
run_server_worker
(
listen_address
,
sock
,
args
,
**
uvicorn_kwargs
)
async
def
run_server_worker
(
listen_address
,
sock
,
args
,
client_config
=
None
,
**
uvicorn_kwargs
)
->
None
:
"""Run a single API server worker."""
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
server_index
=
client_config
.
get
(
"client_index"
,
0
)
if
client_config
else
0
async
with
build_async_engine_client
(
args
,
client_config
)
as
engine_client
:
app
=
build_app
(
args
)
app
=
build_app
(
args
)
vllm_config
=
await
engine_client
.
get_vllm_config
()
vllm_config
=
await
engine_client
.
get_vllm_config
()
await
init_app_state
(
engine_client
,
vllm_config
,
app
.
state
,
args
)
await
init_app_state
(
engine_client
,
vllm_config
,
app
.
state
,
args
)
def
_listen_addr
(
a
:
str
)
->
str
:
logger
.
info
(
"Starting vLLM API server %d on %s"
,
server_index
,
if
is_valid_ipv6_address
(
a
):
listen_address
)
return
'['
+
a
+
']'
return
a
or
"0.0.0.0"
is_ssl
=
args
.
ssl_keyfile
and
args
.
ssl_certfile
logger
.
info
(
"Starting vLLM API server on http%s://%s:%d"
,
"s"
if
is_ssl
else
""
,
_listen_addr
(
sock_addr
[
0
]),
sock_addr
[
1
])
shutdown_task
=
await
serve_http
(
shutdown_task
=
await
serve_http
(
app
,
app
,
sock
=
sock
,
sock
=
sock
,
...
...
vllm/lora/worker_manager.py
View file @
2dbe8c07
...
@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...
@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self
.
add_adapter
(
lora
)
self
.
add_adapter
(
lora
)
def
add_adapter
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_adapter
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if
lora_request
.
lora_int_id
not
in
self
.
list_adapters
():
if
lora_request
.
lora_int_id
not
in
self
.
list_adapters
():
# Load the new adapter first to ensure it is actually valid, before
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.
# evicting any existing adapters.
...
...
vllm/utils.py
View file @
2dbe8c07
...
@@ -2420,6 +2420,7 @@ def make_zmq_socket(
...
@@ -2420,6 +2420,7 @@ def make_zmq_socket(
socket_type
:
Any
,
socket_type
:
Any
,
bind
:
Optional
[
bool
]
=
None
,
bind
:
Optional
[
bool
]
=
None
,
identity
:
Optional
[
bytes
]
=
None
,
identity
:
Optional
[
bytes
]
=
None
,
linger
:
Optional
[
int
]
=
None
,
)
->
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]:
# type: ignore[name-defined]
)
->
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]:
# type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
"""Make a ZMQ socket with the proper bind/connect semantics."""
...
@@ -2439,7 +2440,7 @@ def make_zmq_socket(
...
@@ -2439,7 +2440,7 @@ def make_zmq_socket(
buf_size
=
-
1
# Use system default buffer size
buf_size
=
-
1
# Use system default buffer size
if
bind
is
None
:
if
bind
is
None
:
bind
=
socket_type
!=
zmq
.
PUSH
bind
=
socket_type
not
in
(
zmq
.
PUSH
,
zmq
.
SUB
,
zmq
.
XSUB
)
if
socket_type
in
(
zmq
.
PULL
,
zmq
.
DEALER
,
zmq
.
ROUTER
):
if
socket_type
in
(
zmq
.
PULL
,
zmq
.
DEALER
,
zmq
.
ROUTER
):
socket
.
setsockopt
(
zmq
.
RCVHWM
,
0
)
socket
.
setsockopt
(
zmq
.
RCVHWM
,
0
)
...
@@ -2452,6 +2453,9 @@ def make_zmq_socket(
...
@@ -2452,6 +2453,9 @@ def make_zmq_socket(
if
identity
is
not
None
:
if
identity
is
not
None
:
socket
.
setsockopt
(
zmq
.
IDENTITY
,
identity
)
socket
.
setsockopt
(
zmq
.
IDENTITY
,
identity
)
if
linger
is
not
None
:
socket
.
setsockopt
(
zmq
.
LINGER
,
linger
)
# Determine if the path is a TCP socket with an IPv6 address.
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
# Enable IPv6 on the zmq socket if so.
scheme
,
host
,
_
=
split_zmq_path
(
path
)
scheme
,
host
,
_
=
split_zmq_path
(
path
)
...
...
vllm/v1/core/sched/interface.py
View file @
2dbe8c07
...
@@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
...
@@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
"EngineCoreOutputs"
:
)
->
dict
[
int
,
"EngineCoreOutputs"
]
:
"""Update the scheduler state based on the model runner output.
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
This method is called after the model runner has processed the scheduled
...
@@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
...
@@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
for each request.
for each request.
Returns:
Returns:
A EngineCoreOutputs object containing the outputs for each request.
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
...
@@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
get_request_counts
(
self
)
->
tuple
[
int
,
int
]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
make_stats
(
self
)
->
Optional
[
"SchedulerStats"
]:
def
make_stats
(
self
)
->
Optional
[
"SchedulerStats"
]:
"""Make a SchedulerStats object for logging.
"""Make a SchedulerStats object for logging.
...
...
vllm/v1/core/sched/scheduler.py
View file @
2dbe8c07
...
@@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
...
@@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
# by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently.
# case to track request lifetimes efficiently.
self
.
include_finished_set
=
include_finished_set
self
.
finished_req_ids_dict
:
Optional
[
dict
[
int
,
set
[
str
]]]
=
(
defaultdict
(
set
)
if
include_finished_set
else
None
)
# Scheduling constraints.
# Scheduling constraints.
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
self
.
max_num_running_reqs
=
self
.
scheduler_config
.
max_num_seqs
...
@@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
...
@@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
self
,
self
,
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
model_runner_output
:
ModelRunnerOutput
,
model_runner_output
:
ModelRunnerOutput
,
)
->
EngineCoreOutputs
:
)
->
dict
[
int
,
EngineCoreOutputs
]
:
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
sampled_token_ids
=
model_runner_output
.
sampled_token_ids
spec_token_ids
=
model_runner_output
.
spec_token_ids
spec_token_ids
=
model_runner_output
.
spec_token_ids
logprobs
=
model_runner_output
.
logprobs
logprobs
=
model_runner_output
.
logprobs
...
@@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
...
@@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
new_running
:
list
[
Request
]
=
[]
new_running
:
list
[
Request
]
=
[]
outputs
:
list
[
EngineCoreOutput
]
=
[]
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]
]
=
defaultdict
(
list
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
...
@@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
...
@@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
if
new_token_ids
or
kv_transfer_params
:
if
new_token_ids
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
# Add EngineCoreOutput for this Request.
outputs
.
append
(
outputs
[
request
.
client_index
]
.
append
(
EngineCoreOutput
(
EngineCoreOutput
(
request_id
=
req_id
,
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
new_token_ids
=
new_token_ids
,
...
@@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
...
@@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
self
.
_cached_reqs_data
[
req_data
.
req_id
].
append
(
req_data
)
self
.
_cached_reqs_data
[
req_data
.
req_id
].
append
(
req_data
)
self
.
running
=
new_running
self
.
running
=
new_running
engine_core_outputs
=
EngineCoreOutputs
(
outputs
=
outputs
,
# Create EngineCoreOutputs for all clients that have requests with
scheduler_stats
=
self
.
make_stats
(
spec_decoding_stats
),
# outputs in this step.
)
engine_core_outputs
=
{
if
self
.
include_finished_set
:
client_index
:
EngineCoreOutputs
(
outputs
=
outs
)
#TODO currently sending duplicates here, improve this
for
client_index
,
outs
in
outputs
.
items
()
engine_core_outputs
.
finished_requests
=
(
}
scheduler_output
.
finished_req_ids
|
self
.
finished_req_ids
)
finished_req_ids
=
self
.
finished_req_ids_dict
if
finished_req_ids
is
not
None
:
# Include ids of requests that finished since last outputs
# were sent.
for
client_index
,
finished_set
in
finished_req_ids
.
items
():
# Set finished request set in EngineCoreOutputs for this client.
if
(
eco
:
=
engine_core_outputs
.
get
(
client_index
))
is
not
None
:
eco
.
finished_requests
=
finished_set
else
:
engine_core_outputs
[
client_index
]
=
EngineCoreOutputs
(
finished_requests
=
finished_set
)
finished_req_ids
.
clear
()
if
engine_core_outputs
:
# Return stats to only one of the front-ends.
next
(
iter
(
engine_core_outputs
.
values
())).
scheduler_stats
=
(
self
.
make_stats
(
spec_decoding_stats
))
return
engine_core_outputs
return
engine_core_outputs
def
get_request_counts
(
self
)
->
tuple
[
int
,
int
]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return
len
(
self
.
running
),
len
(
self
.
waiting
)
def
add_request
(
self
,
request
:
Request
)
->
None
:
def
add_request
(
self
,
request
:
Request
)
->
None
:
self
.
waiting
.
append
(
request
)
self
.
waiting
.
append
(
request
)
self
.
requests
[
request
.
request_id
]
=
request
self
.
requests
[
request
.
request_id
]
=
request
...
@@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):
...
@@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):
delay_free_blocks
,
kv_xfer_params
=
self
.
_connector_finished
(
request
)
delay_free_blocks
,
kv_xfer_params
=
self
.
_connector_finished
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
self
.
encoder_cache_manager
.
free
(
request
)
self
.
_cached_reqs_data
.
pop
(
request
.
request_id
,
None
)
request_id
=
request
.
request_id
self
.
finished_req_ids
.
add
(
request
.
request_id
)
self
.
_cached_reqs_data
.
pop
(
request_id
,
None
)
self
.
finished_req_ids
.
add
(
request_id
)
if
self
.
finished_req_ids_dict
is
not
None
:
self
.
finished_req_ids_dict
[
request
.
client_index
].
add
(
request_id
)
if
not
delay_free_blocks
:
if
not
delay_free_blocks
:
self
.
_free_blocks
(
request
)
self
.
_free_blocks
(
request
)
...
...
vllm/v1/engine/__init__.py
View file @
2dbe8c07
...
@@ -44,10 +44,6 @@ class EngineCoreRequest(
...
@@ -44,10 +44,6 @@ class EngineCoreRequest(
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py
request_id
:
str
request_id
:
str
prompt_token_ids
:
list
[
int
]
prompt_token_ids
:
list
[
int
]
mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]]
mm_inputs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargs
]]]
...
@@ -59,6 +55,10 @@ class EngineCoreRequest(
...
@@ -59,6 +55,10 @@ class EngineCoreRequest(
lora_request
:
Optional
[
LoRARequest
]
lora_request
:
Optional
[
LoRARequest
]
cache_salt
:
Optional
[
str
]
cache_salt
:
Optional
[
str
]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index
:
int
=
0
# Used in DP case to indicate which wave of requests this is expected to
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
# a wave finished notification is received.
...
...
vllm/v1/engine/async_llm.py
View file @
2dbe8c07
...
@@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
...
@@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.loggers
import
(
StatLoggerBase
,
StatLoggerFactory
,
from
vllm.v1.metrics.loggers
import
(
StatLoggerBase
,
StatLoggerFactory
,
setup_default_loggers
)
setup_default_loggers
)
from
vllm.v1.metrics.prometheus
import
shutdown_prometheus
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
...
@@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
log_requests
:
bool
=
True
,
log_requests
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
,
)
->
None
:
)
->
None
:
"""
"""
Create an AsyncLLM.
Create an AsyncLLM.
...
@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
...
@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
executor_class
=
executor_class
,
log_stats
=
self
.
log_stats
,
log_stats
=
self
.
log_stats
,
client_addresses
=
client_addresses
,
client_index
=
client_index
,
)
)
if
self
.
stat_loggers
:
if
self
.
stat_loggers
:
for
stat_logger
in
self
.
stat_loggers
[
0
]:
for
stat_logger
in
self
.
stat_loggers
[
0
]:
...
@@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
...
@@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
stat_loggers
:
Optional
[
list
[
StatLoggerFactory
]]
=
None
,
disable_log_requests
:
bool
=
False
,
disable_log_requests
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
disable_log_stats
:
bool
=
False
,
client_addresses
:
Optional
[
dict
[
str
,
str
]]
=
None
,
client_index
:
int
=
0
,
)
->
"AsyncLLM"
:
)
->
"AsyncLLM"
:
if
not
envs
.
VLLM_USE_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
...
@@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
log_requests
=
not
disable_log_requests
,
log_requests
=
not
disable_log_requests
,
log_stats
=
not
disable_log_stats
,
log_stats
=
not
disable_log_stats
,
usage_context
=
usage_context
,
usage_context
=
usage_context
,
client_addresses
=
client_addresses
,
client_index
=
client_index
,
)
)
@
classmethod
@
classmethod
...
@@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
...
@@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
def
shutdown
(
self
):
def
shutdown
(
self
):
"""Shutdown, cleaning up the background proc and IPC."""
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus
()
if
engine_core
:
=
getattr
(
self
,
"engine_core"
,
None
):
if
engine_core
:
=
getattr
(
self
,
"engine_core"
,
None
):
engine_core
.
shutdown
()
engine_core
.
shutdown
()
...
@@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
...
@@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
# background thread once Prometheus overhead is non-trivial.
if
stat_loggers
:
if
stat_loggers
:
assert
outputs
.
scheduler_stats
is
not
None
AsyncLLM
.
_record_stats
(
AsyncLLM
.
_record_stats
(
stat_loggers
[
outputs
.
engine_index
],
stat_loggers
[
outputs
.
engine_index
],
scheduler_stats
=
outputs
.
scheduler_stats
,
scheduler_stats
=
outputs
.
scheduler_stats
,
...
@@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
...
@@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
@
staticmethod
@
staticmethod
def
_record_stats
(
def
_record_stats
(
stat_loggers
:
list
[
StatLoggerBase
],
stat_loggers
:
list
[
StatLoggerBase
],
scheduler_stats
:
SchedulerStats
,
scheduler_stats
:
Optional
[
SchedulerStats
]
,
iteration_stats
:
Optional
[
IterationStats
],
iteration_stats
:
Optional
[
IterationStats
],
):
):
"""static so that it can be used from the output_handler task
"""static so that it can be used from the output_handler task
...
...
vllm/v1/engine/coordinator.py
0 → 100644
View file @
2dbe8c07
# SPDX-License-Identifier: Apache-2.0
import
multiprocessing
import
time
import
weakref
from
typing
import
Optional
import
msgspec.msgpack
import
zmq
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_mp_context
,
get_open_zmq_ipc_path
,
make_zmq_socket
from
vllm.v1.engine
import
EngineCoreOutputs
,
EngineCoreRequestType
from
vllm.v1.serial_utils
import
MsgpackDecoder
from
vllm.v1.utils
import
get_engine_client_zmq_addr
,
shutdown
logger
=
init_logger
(
__name__
)
class
DPCoordinator
:
"""Coordinator process used for data-parallel deployments (DP>1).
Intermediates between multiple DP engine rank processes and one or more
front-end API server processes.
* Collects stats from each DP engine (currently just waiting and running
queue lengths), and publishes these to all front-ends for use in
load-balancing decisions.
* Keeps track of the current DP "request wave" number and running state
of the engines. This is received from the DP rank 0 engine and published
to the front-end processes along with the current load stats.
The engines alternate between a global running/paused state. The global
"request wave" number is a count of the number of times that the workers
collectively move from a running state to a paused state. This transition
is synchronized via the all-reduce operation performed in the
DPEngineCoreProc._has_global_unfinished_reqs method.
* Broadcasts the START_DP_WAVE message to engines to move them from paused
to running state when one engine receives a new request. This can happen
in two cases:
1) A front-end sending a new request while the engines are paused will
concurrently notify the coordinator.
2) An engine receiving a request for a stale request wave while in paused
state will notify the coordinator.
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
"""
def
__init__
(
self
,
parallel_config
:
ParallelConfig
):
# Assume coordinator is colocated with front-end procs.
front_publish_address
=
get_open_zmq_ipc_path
()
dp_size
=
parallel_config
.
data_parallel_size
assert
dp_size
>
1
,
"Coordinator only used for data parallel"
local_only
=
dp_size
==
parallel_config
.
data_parallel_size_local
host
=
parallel_config
.
data_parallel_master_ip
back_publish_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
back_output_address
=
get_engine_client_zmq_addr
(
local_only
,
host
)
context
=
get_mp_context
()
self
.
proc
:
multiprocessing
.
Process
=
context
.
Process
(
target
=
CoordinatorProc
.
run_coordinator
,
name
=
"VLLM_DP_Coordinator"
,
kwargs
=
{
"engine_count"
:
parallel_config
.
data_parallel_size
,
"front_publish_address"
:
front_publish_address
,
"back_output_address"
:
back_output_address
,
"back_publish_address"
:
back_publish_address
,
},
daemon
=
True
)
self
.
proc
.
start
()
self
.
stats_publish_address
=
front_publish_address
self
.
coord_in_address
=
back_publish_address
self
.
coord_out_address
=
back_output_address
self
.
_finalizer
=
weakref
.
finalize
(
self
,
shutdown
,
[
self
.
proc
])
def
get_stats_publish_address
(
self
)
->
str
:
return
self
.
stats_publish_address
def
get_engine_socket_addresses
(
self
)
->
tuple
[
str
,
str
]:
"""Returns tuple of ZMQ input address, output address."""
return
self
.
coord_in_address
,
self
.
coord_out_address
def
close
(
self
):
self
.
_finalizer
()
class
EngineState
:
def
__init__
(
self
):
self
.
request_counts
=
[
0
,
0
]
# [waiting, running]
class
CoordinatorProc
:
def
__init__
(
self
,
engine_count
:
int
):
self
.
ctx
=
zmq
.
Context
()
self
.
engines
=
[
EngineState
()
for
_
in
range
(
engine_count
)]
self
.
current_wave
=
0
self
.
engines_running
=
False
self
.
stats_changed
=
False
@
staticmethod
def
run_coordinator
(
engine_count
:
int
,
front_publish_address
:
str
,
back_output_address
:
str
,
back_publish_address
:
str
,
):
coordinator
=
CoordinatorProc
(
engine_count
=
engine_count
)
try
:
coordinator
.
process_input_socket
(
front_publish_address
,
back_output_address
,
back_publish_address
,
)
except
KeyboardInterrupt
:
logger
.
info
(
"DP Coordinator process exiting"
)
def
process_input_socket
(
self
,
front_publish_address
:
str
,
back_output_address
:
str
,
back_publish_address
:
str
):
decoder
=
MsgpackDecoder
(
EngineCoreOutputs
)
with
make_zmq_socket
(
path
=
front_publish_address
,
# IPC
ctx
=
self
.
ctx
,
socket_type
=
zmq
.
XPUB
,
bind
=
True
,
)
as
publish_front
,
make_zmq_socket
(
path
=
back_output_address
,
# IPC or TCP
ctx
=
self
.
ctx
,
socket_type
=
zmq
.
PULL
,
bind
=
True
,
)
as
output_back
,
make_zmq_socket
(
path
=
back_publish_address
,
# IPC or TCP
ctx
=
self
.
ctx
,
socket_type
=
zmq
.
XPUB
,
bind
=
True
,
)
as
publish_back
:
poller
=
zmq
.
Poller
()
poller
.
register
(
publish_front
,
zmq
.
POLLIN
)
poller
.
register
(
output_back
,
zmq
.
POLLIN
)
last_publish_time
=
0
while
True
:
elapsed
=
int
(
time
.
time
()
*
1000
)
-
last_publish_time
# Send at 100 ms interval if the stats have changed,
# or otherwise every 3 seconds.
wait_for
=
100
if
self
.
stats_changed
else
3000
events
=
poller
.
poll
(
timeout
=
max
(
0
,
wait_for
-
elapsed
))
if
not
events
:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list
=
self
.
_get_engine_counts
()
to_publish
=
(
engine_req_counts_list
,
self
.
current_wave
,
self
.
engines_running
)
publish_front
.
send
(
msgspec
.
msgpack
.
encode
(
to_publish
))
last_publish_time
=
int
(
time
.
time
()
*
1000
)
self
.
stats_changed
=
False
continue
events
=
dict
(
events
)
if
publish_front
in
events
:
buffer
=
publish_front
.
recv
()
if
buffer
==
b
'
\x01
'
:
# Ignore subscription messages.
continue
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude
,
wave
=
msgspec
.
msgpack
.
decode
(
buffer
)
if
wave
<
self
.
current_wave
:
# If the wave number is stale, ensure the message is
# handled by all the engines.
engine_to_exclude
=
None
if
not
self
.
engines_running
:
self
.
engines_running
=
True
self
.
stats_changed
=
True
self
.
_send_start_wave
(
publish_back
,
self
.
current_wave
,
engine_to_exclude
)
if
output_back
in
events
:
# We received a message from one of the engines.
buffer
=
output_back
.
recv
()
outputs
:
EngineCoreOutputs
=
decoder
.
decode
(
buffer
)
assert
not
outputs
.
outputs
assert
outputs
.
utility_output
is
None
eng_index
=
outputs
.
engine_index
if
outputs
.
scheduler_stats
:
# 1. Updated request load stats - update our local
# state with these.
stats
=
self
.
engines
[
eng_index
].
request_counts
stats
[
0
]
=
outputs
.
scheduler_stats
.
num_waiting_reqs
stats
[
1
]
=
outputs
.
scheduler_stats
.
num_running_reqs
self
.
stats_changed
=
True
if
(
wave
:
=
outputs
.
wave_complete
)
is
not
None
:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False)
if
self
.
current_wave
<=
wave
:
logger
.
debug
(
"Moving DP wave from %d to %d."
,
self
.
current_wave
,
wave
)
self
.
current_wave
=
wave
+
1
self
.
engines_running
=
False
self
.
stats_changed
=
True
elif
(
wave
:
=
outputs
.
start_wave
)
is
not
None
and
(
wave
>
self
.
current_wave
or
(
wave
==
self
.
current_wave
and
not
self
.
engines_running
)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger
.
debug
(
"Starting wave %d after notification of "
"stale wave request from engine."
,
wave
)
self
.
current_wave
=
wave
self
.
engines_running
=
True
self
.
stats_changed
=
True
self
.
_send_start_wave
(
publish_back
,
wave
,
eng_index
)
@
staticmethod
def
_send_start_wave
(
socket
:
zmq
.
Socket
,
wave
:
int
,
exclude_engine_index
:
Optional
[
int
]):
"""Broadcast the START_DP_WAVE message to all the engines.
It includes the current wave number and index of engine which
has already received a request with this wave number and so doesn't
require additional notification.
"""
wave_encoded
=
msgspec
.
msgpack
.
encode
((
wave
,
exclude_engine_index
))
socket
.
send_multipart
(
(
EngineCoreRequestType
.
START_DP_WAVE
.
value
,
wave_encoded
))
def
_get_engine_counts
(
self
)
->
list
[
list
[
int
]]:
"""Return list of [waiting, running] count lists for each engine."""
return
[
e
.
request_counts
for
e
in
self
.
engines
]
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment