Unverified Commit 2dbe8c07 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Perf] API-server scaleout with many-to-many server-engine comms (#17546)

parent 84ec470f
...@@ -618,9 +618,11 @@ steps: ...@@ -618,9 +618,11 @@ steps:
- vllm/worker/model_runner.py - vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py - entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py - tests/v1/test_async_llm_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/ - vllm/v1/engine/
commands: commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
......
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import socket
import threading
import time
from typing import Optional
from unittest.mock import patch
import pytest
from vllm.v1.utils import (APIServerProcessManager,
wait_for_completion_or_failure)
# Global variables to control worker behavior
WORKER_RUNTIME_SECONDS = 0.5
# Mock implementation of run_api_server_worker
def mock_run_api_server_worker(listen_address, sock, args, client_config=None):
"""Mock run_api_server_worker that runs for a specific time."""
print(f"Mock worker started with client_config: {client_config}")
time.sleep(WORKER_RUNTIME_SECONDS)
print("Mock worker completed successfully")
@pytest.fixture
def api_server_args():
"""Fixture to provide arguments for APIServerProcessManager."""
sock = socket.socket()
return {
"target_server_fn":
mock_run_api_server_worker,
"listen_address":
"localhost:8000",
"sock":
sock,
"args":
"test_args", # Simple string to avoid pickling issues
"num_servers":
3,
"input_addresses": [
"tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002",
"tcp://127.0.0.1:5003"
],
"output_addresses": [
"tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002",
"tcp://127.0.0.1:6003"
],
"stats_update_address":
"tcp://127.0.0.1:7000",
}
@pytest.mark.parametrize("with_stats_update", [True, False])
def test_api_server_process_manager_init(api_server_args, with_stats_update):
"""Test initializing the APIServerProcessManager."""
# Set the worker runtime to ensure tests complete in reasonable time
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.5
# Copy the args to avoid mutating the
args = api_server_args.copy()
if not with_stats_update:
args.pop("stats_update_address")
manager = APIServerProcessManager(**args)
try:
# Verify the manager was initialized correctly
assert len(manager.processes) == 3
# Verify all processes are running
for proc in manager.processes:
assert proc.is_alive()
print("Waiting for processes to run...")
time.sleep(WORKER_RUNTIME_SECONDS / 2)
# They should still be alive at this point
for proc in manager.processes:
assert proc.is_alive()
finally:
# Always clean up the processes
print("Cleaning up processes...")
manager.close()
# Give processes time to terminate
time.sleep(0.2)
# Verify all processes were terminated
for proc in manager.processes:
assert not proc.is_alive()
@patch("vllm.entrypoints.cli.serve.run_api_server_worker",
mock_run_api_server_worker)
def test_wait_for_completion_or_failure(api_server_args):
"""Test that wait_for_completion_or_failure works with failures."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 1.0
# Create the manager
manager = APIServerProcessManager(**api_server_args)
try:
assert len(manager.processes) == 3
# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}
def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager)
except Exception as e:
result["exception"] = e
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()
# Let all processes run for a short time
time.sleep(0.2)
# All processes should still be running
assert all(proc.is_alive() for proc in manager.processes)
# Now simulate a process failure
print("Simulating process failure...")
manager.processes[0].terminate()
# Wait for the wait_for_completion_or_failure
# to detect and handle the failure
# This should trigger it to terminate all other processes
wait_thread.join(timeout=1.0)
# The wait thread should have exited
assert not wait_thread.is_alive()
# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None
assert "died with exit code" in str(result["exception"])
# All processes should now be terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(), f"Process {i} should not be alive"
finally:
manager.close()
time.sleep(0.2)
@pytest.mark.timeout(30)
def test_normal_completion(api_server_args):
"""Test that wait_for_completion_or_failure works in normal completion."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.1
# Create the manager
manager = APIServerProcessManager(**api_server_args)
try:
# Give processes time to terminate
# wait for processes to complete
remaining_processes = manager.processes.copy()
while remaining_processes:
for proc in remaining_processes:
if not proc.is_alive():
remaining_processes.remove(proc)
time.sleep(0.1)
# Verify all processes have terminated
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"Process {i} still alive after terminate()"
# Now call wait_for_completion_or_failure
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure(api_server_manager=manager)
finally:
# Clean up just in case
manager.close()
time.sleep(0.2)
@pytest.mark.timeout(30)
def test_external_process_monitoring(api_server_args):
"""Test that wait_for_completion_or_failure handles additional processes."""
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 100
# Create and start the external process
# (simulates local_engine_manager or coordinator)
spawn_context = multiprocessing.get_context("spawn")
external_proc = spawn_context.Process(target=mock_run_api_server_worker,
name="MockExternalProcess")
external_proc.start()
# Create the class to simulate a coordinator
class MockCoordinator:
def __init__(self, proc):
self.proc = proc
def close(self):
if self.proc.is_alive():
self.proc.terminate()
self.proc.join(timeout=0.5)
# Create a mock coordinator with the external process
mock_coordinator = MockCoordinator(external_proc)
# Create the API server manager
manager = APIServerProcessManager(**api_server_args)
try:
# Verify manager initialization
assert len(manager.processes) == 3
# Create a result capture for the thread
result: dict[str, Optional[Exception]] = {"exception": None}
def run_with_exception_capture():
try:
wait_for_completion_or_failure(api_server_manager=manager,
coordinator=mock_coordinator)
except Exception as e:
result["exception"] = e
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture,
daemon=True)
wait_thread.start()
# Terminate the external process to trigger a failure
time.sleep(0.2)
external_proc.terminate()
# Wait for the thread to detect the failure
wait_thread.join(timeout=1.0)
# The wait thread should have completed
assert not wait_thread.is_alive(
), "wait_for_completion_or_failure thread still running"
# Verify that an exception was raised with appropriate error message
assert result["exception"] is not None, "No exception was raised"
error_message = str(result["exception"])
assert "died with exit code" in error_message, \
f"Unexpected error message: {error_message}"
assert "MockExternalProcess" in error_message, \
f"Error doesn't mention external process: {error_message}"
# Verify that all API server processes were terminated as a result
for i, proc in enumerate(manager.processes):
assert not proc.is_alive(
), f"API server process {i} was not terminated"
finally:
# Clean up
manager.close()
mock_coordinator.close()
time.sleep(0.2)
...@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs ...@@ -28,7 +28,7 @@ from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
...@@ -99,7 +99,8 @@ class RemoteOpenAIServer: ...@@ -99,7 +99,8 @@ class RemoteOpenAIServer:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.") description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser) subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args]) args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost') self.host = str(args.host or 'localhost')
self.port = int(args.port) self.port = int(args.port)
......
...@@ -45,7 +45,6 @@ def make_request(request_id, ...@@ -45,7 +45,6 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100, eos_token_id=100,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
) )
......
...@@ -38,7 +38,6 @@ def make_request(request_id, ...@@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params=SamplingParams(max_tokens=17, sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs), prompt_logprobs=prompt_logprobs),
eos_token_id=100, eos_token_id=100,
arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
) )
......
...@@ -138,7 +138,6 @@ def create_requests(num_requests: int, ...@@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
) )
requests.append(request) requests.append(request)
return requests return requests
...@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -744,7 +743,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i]) assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet # No draft or accepted tokens counted yet
assert engine_core_outputs.scheduler_stats.spec_decoding_stats is None assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
# Schedule the speculated tokens for validation # Schedule the speculated tokens for validation
output = scheduler.schedule() output = scheduler.schedule()
...@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -772,7 +772,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
engine_core_outputs = scheduler.update_from_output(output, engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) model_runner_output)
scheduler_stats = engine_core_outputs.scheduler_stats scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0: if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None assert scheduler_stats.spec_decoding_stats is None
else: else:
...@@ -843,7 +844,7 @@ def _step_until_done( ...@@ -843,7 +844,7 @@ def _step_until_done(
# We should be in the decode phase now. # We should be in the decode phase now.
assert num_scheduled_tokens == 1 assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0 assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output) ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True all_done = True
for eco in ecos.outputs: for eco in ecos.outputs:
if eco.finish_reason is None: if eco.finish_reason is None:
......
...@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -88,7 +88,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.running) == 4 assert len(engine_core.scheduler.running) == 4
# Loop through until they are all done. # Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0: while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
...@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -163,11 +163,11 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
req0.request_id = req1.request_id = "test" req0.request_id = req1.request_id = "test"
engine_core.add_request(req0) engine_core.add_request(req0)
while len(engine_core.step()[0].outputs) > 0: while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass pass
engine_core.add_request(req1) engine_core.add_request(req1)
while len(engine_core.step()[0].outputs) > 0: while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
...@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): ...@@ -207,7 +207,7 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done. # Loop through until they are all done.
while len(engine_core.step()[0].outputs) > 0: while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass pass
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
...@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -327,7 +327,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[1] == 4 assert scheduler_output.num_scheduled_tokens[1] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0. # Batch queue is full. Finish Batch 2. Get first token of req0.
output = engine_core.step_with_batch_queue()[0] output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
...@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -339,7 +339,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert scheduler_output.num_scheduled_tokens[0] == 1 assert scheduler_output.num_scheduled_tokens[0] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1. # Batch queue is full. Finish Batch 3. Get first token of req1.
output = engine_core.step_with_batch_queue()[0] output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13 assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
...@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -362,7 +362,7 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
if step % 2 == 0: if step % 2 == 0:
# Even steps consumes an output. # Even steps consumes an output.
assert output is not None assert output is not None
assert len(output.outputs) == 1 assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests: if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[ assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id] req_id].num_tokens == expected_num_tokens[req_id]
......
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
MODEL_NAME = "ibm-research/PowerMoE-3b"
DP_SIZE = os.getenv("DP_SIZE", "1")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
"--api-server-count",
"4",
"--data_parallel_size",
DP_SIZE,
]
@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
model_name: str) -> None:
async def make_request():
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request
result = await make_request()
assert result is not None
await asyncio.sleep(0.5)
# Send two bursts of requests
num_requests = 100
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
tasks = [make_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(results) == num_requests
assert all(completion is not None for completion in results)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request():
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request
result = await make_streaming_request()
assert result is not None
await asyncio.sleep(0.5)
# Send two bursts of requests
num_requests = 100
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
tasks = [make_streaming_request() for _ in range(num_requests)]
results = await asyncio.gather(*tasks)
assert len(
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
...@@ -43,7 +43,7 @@ def test_basic_lifecycle(): ...@@ -43,7 +43,7 @@ def test_basic_lifecycle():
# Ensure the request is finished after 1 tokens. # Ensure the request is finished after 1 tokens.
assert request.is_finished() assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs.outputs[0] output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None assert output.kv_transfer_params is not None
...@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle(): ...@@ -165,7 +165,7 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote]) model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output) eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco.outputs[0].kv_transfer_params kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit. # Ensure we send all block ids, even if there is a cache hit.
assert (len( assert (len(
......
...@@ -61,7 +61,7 @@ def test_basic_lifecycle(): ...@@ -61,7 +61,7 @@ def test_basic_lifecycle():
# (1c): update_from_output() # (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output, engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output) model_runner_output)
assert len(engine_core_outputs.outputs) == 0 assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2): # STEP (2):
# (2a): schedule(): nothing happens! # (2a): schedule(): nothing happens!
...@@ -112,7 +112,7 @@ def test_basic_lifecycle(): ...@@ -112,7 +112,7 @@ def test_basic_lifecycle():
model_runner_output) model_runner_output)
scheduler.schedule() scheduler.schedule()
outputs = engine_core_outputs.outputs outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1 assert len(outputs) == 1
output = outputs[0] output = outputs[0]
assert output.finish_reason == FinishReason.STOP assert output.finish_reason == FinishReason.STOP
...@@ -335,7 +335,7 @@ def test_full_block_prompt(): ...@@ -335,7 +335,7 @@ def test_full_block_prompt():
model_runner_output) model_runner_output)
scheduler.schedule() scheduler.schedule()
outputs = engine_core_outputs.outputs outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1 assert len(outputs) == 1
output = outputs[0] output = outputs[0]
assert output.finish_reason == FinishReason.STOP assert output.finish_reason == FinishReason.STOP
......
...@@ -153,7 +153,6 @@ def create_request( ...@@ -153,7 +153,6 @@ def create_request(
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
) )
req.kv_transfer_params = kv_transfer_params req.kv_transfer_params = kv_transfer_params
return req return req
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import os
import signal import signal
import sys
import uvloop import uvloop
import zmq
import vllm.envs as envs import vllm.envs as envs
from vllm import AsyncEngineArgs from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
from vllm.entrypoints.openai.cli_args import (make_arg_parser, from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args) validate_parsed_serve_args)
from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG, from vllm.entrypoints.utils import (VLLM_SERVE_PARSER_EPILOG,
show_filtered_argument_or_group_from_help) show_filtered_argument_or_group_from_help)
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand): ...@@ -36,9 +47,12 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, 'model_tag') and args.model_tag is not None: if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag args.model = args.model_tag
if args.headless: if args.headless or args.api_server_count < 1:
run_headless(args) run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else: else:
# Single API server (this process).
uvloop.run(run_server(args)) uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None: def validate(self, args: argparse.Namespace) -> None:
...@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand): ...@@ -69,6 +83,11 @@ class ServeSubcommand(CLISubcommand):
type=int, type=int,
default=0, default=0,
help='Starting data parallel rank for secondary nodes.') help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument('--api-server-count',
'-asc',
type=int,
default=1,
help='How many API server processes to run.')
serve_parser.add_argument( serve_parser.add_argument(
"--config", "--config",
type=str, type=str,
...@@ -91,22 +110,25 @@ def cmd_init() -> list[CLISubcommand]: ...@@ -91,22 +110,25 @@ def cmd_init() -> list[CLISubcommand]:
def run_headless(args: argparse.Namespace): def run_headless(args: argparse.Namespace):
if args.api_server_count > 1:
raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig. # Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1") raise ValueError("Headless mode is only supported for V1")
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port) handshake_address = get_tcp_uri(host, port)
if local_engine_count <= 0: if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in " raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode") "headless mode")
# Catch SIGTERM and SIGINT to allow graceful shutdown. # Catch SIGTERM and SIGINT to allow graceful shutdown.
...@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace): ...@@ -119,7 +141,7 @@ def run_headless(args: argparse.Namespace):
logger.info( logger.info(
"Launching %d data parallel engine(s) in headless mode, " "Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address) "with head node address %s.", local_engine_count, handshake_address)
# Create the engines. # Create the engines.
engine_manager = CoreEngineProcManager( engine_manager = CoreEngineProcManager(
...@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace): ...@@ -129,7 +151,7 @@ def run_headless(args: argparse.Namespace):
local_start_index=0, local_start_index=0,
vllm_config=vllm_config, vllm_config=vllm_config,
on_head_node=False, on_head_node=False,
input_address=input_address, handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config), executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
) )
...@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace): ...@@ -139,3 +161,142 @@ def run_headless(args: argparse.Namespace):
finally: finally:
logger.info("Shutting down.") logger.info("Shutting down.")
engine_manager.close() engine_manager.close()
def run_multi_api_server(args: argparse.Namespace):
assert not args.headless
num_api_servers = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1:
setup_multiprocess_prometheus()
listen_address, sock = setup_server(args)
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config
if num_api_servers > 1:
if not envs.VLLM_USE_V1:
raise ValueError("api_server_count > 1 is only supported for V1")
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")
if model_config.is_multimodal_model and not (
model_config.disable_mm_preprocessor_cache):
logger.warning(
"Multi-model preprocessor cache will be disabled for"
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
parallel_config = vllm_config.parallel_config
assert parallel_config.data_parallel_rank == 0
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
local_only = local_engine_count == dp_size
# Set up input and output addresses.
input_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
output_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
addresses = EngineZmqAddresses(
inputs=input_addresses,
outputs=output_addresses,
)
# Set up coordinator for dp > 1.
coordinator = None
stats_update_address = None
if dp_size > 1:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
stats_update_address = coordinator.get_stats_publish_address()
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
# Start local engines.
if not local_engine_count:
local_engine_manager = None
else:
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=0,
local_start_index=0)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)
# Wait for engine handshakes to complete.
core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
wait_for_engine_startup(
handshake_socket,
addresses,
core_engines,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
local_engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Entrypoint for individual API server worker processes."""
# Add process-specific prefix to stdout and stderr.
from multiprocessing import current_process
process_name = current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
uvloop.run(
run_server_worker(listen_address, sock, args, client_config,
**uvicorn_kwargs))
...@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager ...@@ -17,7 +17,7 @@ from contextlib import asynccontextmanager
from functools import partial from functools import partial
from http import HTTPStatus from http import HTTPStatus
from json import JSONDecodeError from json import JSONDecodeError
from typing import Annotated, Optional from typing import Annotated, Any, Optional
import prometheus_client import prometheus_client
import regex as re import regex as re
...@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request ...@@ -26,6 +26,8 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State from starlette.datastructures import State
from starlette.routing import Mount from starlette.routing import Mount
...@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer ...@@ -97,6 +99,7 @@ from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path, from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit) is_valid_ipv6_address, set_ulimit)
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
...@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI): ...@@ -142,14 +145,17 @@ async def lifespan(app: FastAPI):
@asynccontextmanager @asynccontextmanager
async def build_async_engine_client( async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]: args: Namespace,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle # Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit # Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args( async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine: engine_args, args.disable_frontend_multiprocessing,
client_config) as engine:
yield engine yield engine
...@@ -157,6 +163,7 @@ async def build_async_engine_client( ...@@ -157,6 +163,7 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args( async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False, disable_frontend_multiprocessing: bool = False,
client_config: Optional[dict[str, Any]] = None,
) -> AsyncIterator[EngineClient]: ) -> AsyncIterator[EngineClient]:
""" """
Create EngineClient, either: Create EngineClient, either:
...@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args( ...@@ -179,12 +186,16 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None async_llm: Optional[AsyncLLM] = None
client_index = client_config.pop(
"client_index") if client_config else 0
try: try:
async_llm = AsyncLLM.from_vllm_config( async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config, vllm_config=vllm_config,
usage_context=usage_context, usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests, disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats) disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_index=client_index)
# Don't keep the dummy data in memory # Don't keep the dummy data in memory
await async_llm.reset_mm_cache() await async_llm.reset_mm_cache()
...@@ -318,22 +329,9 @@ class PrometheusResponse(Response): ...@@ -318,22 +329,9 @@ class PrometheusResponse(Response):
def mount_metrics(app: FastAPI): def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing. """Mount prometheus metrics to a FastAPI app."""
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (REGISTRY, CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
registry = REGISTRY
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) registry = get_prometheus_registry()
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# `response_class=PrometheusResponse` is needed to return an HTTP response # `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
...@@ -1256,13 +1254,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: ...@@ -1256,13 +1254,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket:
return sock return sock
async def run_server(args, **uvicorn_kwargs) -> None: def validate_api_server_args(args):
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valid_tool_parses = ToolParserManager.tool_parsers.keys() valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \ if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses: and args.tool_call_parser not in valid_tool_parses:
...@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -1276,6 +1268,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
f"invalid reasoning parser: {args.reasoning_parser} " f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})") f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger.info("vLLM API server version %s", VLLM_VERSION)
log_non_default_args(args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up. # workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray. # This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204 # see https://github.com/vllm-project/vllm/issues/8204
...@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None: ...@@ -1292,22 +1297,41 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client: addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(
addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
async def run_server_worker(listen_address,
sock,
args,
client_config=None,
**uvicorn_kwargs) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
async with build_async_engine_client(args, client_config) as engine_client:
app = build_app(args) app = build_app(args)
vllm_config = await engine_client.get_vllm_config() vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args) await init_app_state(engine_client, vllm_config, app.state, args)
def _listen_addr(a: str) -> str: logger.info("Starting vLLM API server %d on %s", server_index,
if is_valid_ipv6_address(a): listen_address)
return '[' + a + ']'
return a or "0.0.0.0"
is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
sock=sock, sock=sock,
......
...@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): ...@@ -229,6 +229,11 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
self.add_adapter(lora) self.add_adapter(lora)
def add_adapter(self, lora_request: LoRARequest) -> bool: def add_adapter(self, lora_request: LoRARequest) -> bool:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if lora_request.lora_int_id not in self.list_adapters(): if lora_request.lora_int_id not in self.list_adapters():
# Load the new adapter first to ensure it is actually valid, before # Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters. # evicting any existing adapters.
......
...@@ -2420,6 +2420,7 @@ def make_zmq_socket( ...@@ -2420,6 +2420,7 @@ def make_zmq_socket(
socket_type: Any, socket_type: Any,
bind: Optional[bool] = None, bind: Optional[bool] = None,
identity: Optional[bytes] = None, identity: Optional[bytes] = None,
linger: Optional[int] = None,
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics.""" """Make a ZMQ socket with the proper bind/connect semantics."""
...@@ -2439,7 +2440,7 @@ def make_zmq_socket( ...@@ -2439,7 +2440,7 @@ def make_zmq_socket(
buf_size = -1 # Use system default buffer size buf_size = -1 # Use system default buffer size
if bind is None: if bind is None:
bind = socket_type != zmq.PUSH bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVHWM, 0)
...@@ -2452,6 +2453,9 @@ def make_zmq_socket( ...@@ -2452,6 +2453,9 @@ def make_zmq_socket(
if identity is not None: if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity) socket.setsockopt(zmq.IDENTITY, identity)
if linger is not None:
socket.setsockopt(zmq.LINGER, linger)
# Determine if the path is a TCP socket with an IPv6 address. # Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so. # Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path) scheme, host, _ = split_zmq_path(path)
......
...@@ -45,7 +45,7 @@ class SchedulerInterface(ABC): ...@@ -45,7 +45,7 @@ class SchedulerInterface(ABC):
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput", model_runner_output: "ModelRunnerOutput",
) -> "EngineCoreOutputs": ) -> dict[int, "EngineCoreOutputs"]:
"""Update the scheduler state based on the model runner output. """Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled This method is called after the model runner has processed the scheduled
...@@ -55,7 +55,8 @@ class SchedulerInterface(ABC): ...@@ -55,7 +55,8 @@ class SchedulerInterface(ABC):
for each request. for each request.
Returns: Returns:
A EngineCoreOutputs object containing the outputs for each request. A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -126,6 +127,11 @@ class SchedulerInterface(ABC): ...@@ -126,6 +127,11 @@ class SchedulerInterface(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise NotImplementedError
@abstractmethod @abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]: def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging. """Make a SchedulerStats object for logging.
......
...@@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface): ...@@ -58,7 +58,8 @@ class Scheduler(SchedulerInterface):
# request ids should be included in the EngineCoreOutputs returned # request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine # by update_from_outputs(). This is currently used in the multi-engine
# case to track request lifetimes efficiently. # case to track request lifetimes efficiently.
self.include_finished_set = include_finished_set self.finished_req_ids_dict: Optional[dict[int, set[str]]] = (
defaultdict(set) if include_finished_set else None)
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
...@@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface): ...@@ -693,7 +694,7 @@ class Scheduler(SchedulerInterface):
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs: ) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs logprobs = model_runner_output.logprobs
...@@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface): ...@@ -701,7 +702,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: list[Request] = [] new_running: list[Request] = []
outputs: list[EngineCoreOutput] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
...@@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface): ...@@ -797,7 +798,7 @@ class Scheduler(SchedulerInterface):
if new_token_ids or kv_transfer_params: if new_token_ids or kv_transfer_params:
# Add EngineCoreOutput for this Request. # Add EngineCoreOutput for this Request.
outputs.append( outputs[request.client_index].append(
EngineCoreOutput( EngineCoreOutput(
request_id=req_id, request_id=req_id,
new_token_ids=new_token_ids, new_token_ids=new_token_ids,
...@@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface): ...@@ -828,17 +829,38 @@ class Scheduler(SchedulerInterface):
self._cached_reqs_data[req_data.req_id].append(req_data) self._cached_reqs_data[req_data.req_id].append(req_data)
self.running = new_running self.running = new_running
engine_core_outputs = EngineCoreOutputs(
outputs=outputs, # Create EngineCoreOutputs for all clients that have requests with
scheduler_stats=self.make_stats(spec_decoding_stats), # outputs in this step.
) engine_core_outputs = {
if self.include_finished_set: client_index: EngineCoreOutputs(outputs=outs)
#TODO currently sending duplicates here, improve this for client_index, outs in outputs.items()
engine_core_outputs.finished_requests = ( }
scheduler_output.finished_req_ids | self.finished_req_ids)
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids is not None:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
# Set finished request set in EngineCoreOutputs for this client.
if (eco := engine_core_outputs.get(client_index)) is not None:
eco.finished_requests = finished_set
else:
engine_core_outputs[client_index] = EngineCoreOutputs(
finished_requests=finished_set)
finished_req_ids.clear()
if engine_core_outputs:
# Return stats to only one of the front-ends.
next(iter(engine_core_outputs.values())).scheduler_stats = (
self.make_stats(spec_decoding_stats))
return engine_core_outputs return engine_core_outputs
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
self.waiting.append(request) self.waiting.append(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
...@@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface): ...@@ -880,8 +902,11 @@ class Scheduler(SchedulerInterface):
delay_free_blocks, kv_xfer_params = self._connector_finished(request) delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request) self.encoder_cache_manager.free(request)
self._cached_reqs_data.pop(request.request_id, None) request_id = request.request_id
self.finished_req_ids.add(request.request_id) self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
if not delay_free_blocks: if not delay_free_blocks:
self._free_blocks(request) self._free_blocks(request)
......
...@@ -44,10 +44,6 @@ class EngineCoreRequest( ...@@ -44,10 +44,6 @@ class EngineCoreRequest(
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg] gc=False): # type: ignore[call-arg]
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
# due to circular imports and typing we have in data.py
request_id: str request_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
...@@ -59,6 +55,10 @@ class EngineCoreRequest( ...@@ -59,6 +55,10 @@ class EngineCoreRequest(
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
cache_salt: Optional[str] cache_salt: Optional[str]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
client_index: int = 0
# Used in DP case to indicate which wave of requests this is expected to # Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before # belong to, to cover a race condition where the request is sent before
# a wave finished notification is received. # a wave finished notification is received.
......
...@@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor ...@@ -36,6 +36,7 @@ from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers) setup_default_loggers)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,6 +55,8 @@ class AsyncLLM(EngineClient): ...@@ -54,6 +55,8 @@ class AsyncLLM(EngineClient):
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None: ) -> None:
""" """
Create an AsyncLLM. Create an AsyncLLM.
...@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient): ...@@ -124,6 +127,8 @@ class AsyncLLM(EngineClient):
vllm_config=vllm_config, vllm_config=vllm_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=self.log_stats, log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
) )
if self.stat_loggers: if self.stat_loggers:
for stat_logger in self.stat_loggers[0]: for stat_logger in self.stat_loggers[0]:
...@@ -145,6 +150,8 @@ class AsyncLLM(EngineClient): ...@@ -145,6 +150,8 @@ class AsyncLLM(EngineClient):
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
disable_log_requests: bool = False, disable_log_requests: bool = False,
disable_log_stats: bool = False, disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "AsyncLLM": ) -> "AsyncLLM":
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError( raise ValueError(
...@@ -162,6 +169,8 @@ class AsyncLLM(EngineClient): ...@@ -162,6 +169,8 @@ class AsyncLLM(EngineClient):
log_requests=not disable_log_requests, log_requests=not disable_log_requests,
log_stats=not disable_log_stats, log_stats=not disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
client_addresses=client_addresses,
client_index=client_index,
) )
@classmethod @classmethod
...@@ -195,6 +204,8 @@ class AsyncLLM(EngineClient): ...@@ -195,6 +204,8 @@ class AsyncLLM(EngineClient):
def shutdown(self): def shutdown(self):
"""Shutdown, cleaning up the background proc and IPC.""" """Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if engine_core := getattr(self, "engine_core", None): if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown() engine_core.shutdown()
...@@ -398,7 +409,6 @@ class AsyncLLM(EngineClient): ...@@ -398,7 +409,6 @@ class AsyncLLM(EngineClient):
# TODO(rob): make into a coroutine and launch it in # TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial. # background thread once Prometheus overhead is non-trivial.
if stat_loggers: if stat_loggers:
assert outputs.scheduler_stats is not None
AsyncLLM._record_stats( AsyncLLM._record_stats(
stat_loggers[outputs.engine_index], stat_loggers[outputs.engine_index],
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
...@@ -422,7 +432,7 @@ class AsyncLLM(EngineClient): ...@@ -422,7 +432,7 @@ class AsyncLLM(EngineClient):
@staticmethod @staticmethod
def _record_stats( def _record_stats(
stat_loggers: list[StatLoggerBase], stat_loggers: list[StatLoggerBase],
scheduler_stats: SchedulerStats, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
): ):
"""static so that it can be used from the output_handler task """static so that it can be used from the output_handler task
......
# SPDX-License-Identifier: Apache-2.0
import multiprocessing
import time
import weakref
from typing import Optional
import msgspec.msgpack
import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
logger = init_logger(__name__)
class DPCoordinator:
"""Coordinator process used for data-parallel deployments (DP>1).
Intermediates between multiple DP engine rank processes and one or more
front-end API server processes.
* Collects stats from each DP engine (currently just waiting and running
queue lengths), and publishes these to all front-ends for use in
load-balancing decisions.
* Keeps track of the current DP "request wave" number and running state
of the engines. This is received from the DP rank 0 engine and published
to the front-end processes along with the current load stats.
The engines alternate between a global running/paused state. The global
"request wave" number is a count of the number of times that the workers
collectively move from a running state to a paused state. This transition
is synchronized via the all-reduce operation performed in the
DPEngineCoreProc._has_global_unfinished_reqs method.
* Broadcasts the START_DP_WAVE message to engines to move them from paused
to running state when one engine receives a new request. This can happen
in two cases:
1) A front-end sending a new request while the engines are paused will
concurrently notify the coordinator.
2) An engine receiving a request for a stale request wave while in paused
state will notify the coordinator.
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
"""
def __init__(self, parallel_config: ParallelConfig):
# Assume coordinator is colocated with front-end procs.
front_publish_address = get_open_zmq_ipc_path()
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
local_only = dp_size == parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
back_publish_address = get_engine_client_zmq_addr(local_only, host)
back_output_address = get_engine_client_zmq_addr(local_only, host)
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
target=CoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
kwargs={
"engine_count": parallel_config.data_parallel_size,
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
},
daemon=True)
self.proc.start()
self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
self.coord_out_address = back_output_address
self._finalizer = weakref.finalize(self, shutdown, [self.proc])
def get_stats_publish_address(self) -> str:
return self.stats_publish_address
def get_engine_socket_addresses(self) -> tuple[str, str]:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address
def close(self):
self._finalizer()
class EngineState:
def __init__(self):
self.request_counts = [0, 0] # [waiting, running]
class CoordinatorProc:
def __init__(self, engine_count: int):
self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)]
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod
def run_coordinator(
engine_count: int,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
):
coordinator = CoordinatorProc(engine_count=engine_count)
try:
coordinator.process_input_socket(
front_publish_address,
back_output_address,
back_publish_address,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
def process_input_socket(self, front_publish_address: str,
back_output_address: str,
back_publish_address: str):
decoder = MsgpackDecoder(EngineCoreOutputs)
with make_zmq_socket(
path=front_publish_address, # IPC
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_front, make_zmq_socket(
path=back_output_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.PULL,
bind=True,
) as output_back, make_zmq_socket(
path=back_publish_address, # IPC or TCP
ctx=self.ctx,
socket_type=zmq.XPUB,
bind=True,
) as publish_back:
poller = zmq.Poller()
poller.register(publish_front, zmq.POLLIN)
poller.register(output_back, zmq.POLLIN)
last_publish_time = 0
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at 100 ms interval if the stats have changed,
# or otherwise every 3 seconds.
wait_for = 100 if self.stats_changed else 3000
events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts()
to_publish = (engine_req_counts_list, self.current_wave,
self.engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue
events = dict(events)
if publish_front in events:
buffer = publish_front.recv()
if buffer == b'\x01':
# Ignore subscription messages.
continue
# We received a message on the front-end XPUB socket,
# from an API server sending a new request while the
# engines are paused, so that we can wake the other
# engines.
engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
if wave < self.current_wave:
# If the wave number is stale, ensure the message is
# handled by all the engines.
engine_to_exclude = None
if not self.engines_running:
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, self.current_wave,
engine_to_exclude)
if output_back in events:
# We received a message from one of the engines.
buffer = output_back.recv()
outputs: EngineCoreOutputs = decoder.decode(buffer)
assert not outputs.outputs
assert outputs.utility_output is None
eng_index = outputs.engine_index
if outputs.scheduler_stats:
# 1. Updated request load stats - update our local
# state with these.
stats = self.engines[eng_index].request_counts
stats[0] = outputs.scheduler_stats.num_waiting_reqs
stats[1] = outputs.scheduler_stats.num_running_reqs
self.stats_changed = True
if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've
# moved into the global paused state
# (engines_running==False)
if self.current_wave <= wave:
logger.debug("Moving DP wave from %d to %d.",
self.current_wave, wave)
self.current_wave = wave + 1
self.engines_running = False
self.stats_changed = True
elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or
(wave == self.current_wave
and not self.engines_running)):
# 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the
# next wave (race condition handling).
logger.debug(
"Starting wave %d after notification of "
"stale wave request from engine.", wave)
self.current_wave = wave
self.engines_running = True
self.stats_changed = True
self._send_start_wave(publish_back, wave, eng_index)
@staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]):
"""Broadcast the START_DP_WAVE message to all the engines.
It includes the current wave number and index of engine which
has already received a request with this wave number and so doesn't
require additional notification.
"""
wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index))
socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine."""
return [e.request_counts for e in self.engines]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment