Unverified Commit 7c7714d8 authored by Alexander Matveev's avatar Alexander Matveev Committed by GitHub
Browse files

[Core][Bugfix][Perf] Introduce `MQLLMEngine` to avoid `asyncio` OH (#8157)


Co-authored-by: default avatarNick Hill <nickhill@us.ibm.com>
Co-authored-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: default avatarSimon Mo <simon.mo@hey.com>
parent 9d104b5b
......@@ -43,13 +43,15 @@ steps:
fast_check: true
source_file_dependencies:
- vllm/
- tests/mq_llm_engine
- tests/async_engine
- tests/test_inputs
- tests/multimodal
- tests/test_utils
- tests/worker
commands:
- pytest -v -s async_engine # Async Engine
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
......
......@@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/.
.. tip::
To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_TIMEOUT=1800000``
Example commands and usage:
===========================
......
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
@pytest.fixture(scope="module")
def server():
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--enforce-eager",
"--chat-template",
str(chatml_jinja_path),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
@pytest.mark.asyncio
async def test_single_completion(client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 5
@pytest.mark.asyncio
async def test_single_chat_session(client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# test single completion
chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert chat_completion.id is not None
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=55, total_tokens=65)
message = choice.message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid
import pytest
import pytest_asyncio
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
dummy_engine = unittest.mock.AsyncMock()
def dummy_engine_builder(*args, **kwargs):
return dummy_engine
with monkeypatch.context() as m:
m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
try:
yield server
finally:
server_task.cancel()
server.cleanup()
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
client = AsyncEngineRPCClient(rpc_path=tmp_socket)
# Sanity check: the server is connected
await client._wait_for_server_rpc()
try:
yield client
finally:
client.close()
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server _not_ reply with a model config
m.setattr(dummy_server, "get_config", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# And ensure the task completes anyway
# (client.setup() invokes server.get_config())
client_task = asyncio.get_running_loop().create_task(client.setup())
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Hang all abort requests
m.setattr(dummy_server, "abort", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# The client should suppress timeouts on `abort`s
# and return normally, assuming the server will eventually
# abort the request.
client_task = asyncio.get_running_loop().create_task(
client.abort("test request id"))
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
monkeypatch, dummy_server, client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server raise some random exception
exception = RuntimeError("Client test exception")
def raiser():
raise exception
m.setattr(dummy_server.engine, "get_model_config", raiser)
m.setattr(client, "_data_timeout", 10)
client_task = asyncio.get_running_loop().create_task(client.setup())
# And ensure the task completes, raising the exception
with pytest.raises(RuntimeError, match=str(exception)):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
client.close()
# Healthchecks and generate requests will fail with explicit errors
with pytest.raises(RPCClientClosedError):
await client.check_health()
with pytest.raises(RPCClientClosedError):
async for _ in client.generate(None, None, None):
pass
# But no-ops like aborting will pass
await client.abort("test-request-id")
await client.do_log_stats()
......@@ -18,38 +18,32 @@ TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "4096", "--enable-chunked-prefill",
"--disable-log-requests", "--enforce-eager"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_data(server):
return {
"url": f"{server.url_for('v1')}/completions",
}
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_lm_eval_accuracy(more_args):
args = list(DEFAULT_ARGS)
args.extend(more_args)
print(f"Running with: {args}")
def test_lm_eval_accuracy(server_data):
model_args = (f"model={MODEL_NAME},"
f"base_url={server_data['url']},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
url = f"{remote_server.url_for('v1')}/completions"
model_args = (
f"model={MODEL_NAME},"
f"base_url={url},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
......@@ -5,7 +5,7 @@ from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ..utils import VLLM_PATH
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
......
import time
import pytest
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
@pytest.mark.asyncio
async def test_mp_crash_detection():
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# use an invalid tensor_parallel_size to trigger the
# error in the server
args.tensor_parallel_size = 65536
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# it should not crash, when cuda is initialized
# in the API server process
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass
......@@ -4,7 +4,7 @@ from dataclasses import dataclass
from unittest.mock import MagicMock
from vllm.config import MultiModalConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer
......@@ -52,8 +52,9 @@ def test_async_serving_chat_init():
def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLMEngine)
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
......
......@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
......@@ -18,7 +18,7 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async def _async_serving_engine_init():
mock_engine_client = MagicMock(spec=AsyncEngineClient)
mock_engine_client = MagicMock(spec=EngineClient)
mock_model_config = MagicMock(spec=ModelConfig)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048
......
......@@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path):
prompt="Hello, my name is")
# Now the server should shut down
return_code = remote_server.proc.wait(timeout=3)
return_code = remote_server.proc.wait(timeout=8)
assert return_code is not None
"""Test that aborting is handled properly."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_id_to_be_aborted = "request-aborted"
request_ids_a = [f"request-a-{idx}" for idx in range(10)]
request_ids_b = [f"request-b-{idx}" for idx in range(10)]
# Requests started before one to be aborted.
tasks = []
for request_id in request_ids_a:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Aborted.
task_aborted = asyncio.create_task(
generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
# Requests started after one to be aborted.
for request_id in request_ids_b:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Actually abort.
await asyncio.sleep(0.5)
await client.abort(request_id_to_be_aborted)
# Confirm that we got all the EXPECTED tokens from the requests.
for task in tasks:
count, request_id = await task
assert count == EXPECTED_TOKENS, (
f"{request_id} generated only {count} tokens")
# Cancel task (this will hang indefinitely if not).
task_aborted.cancel()
# Shutdown.
client.close()
"""Test that various errors are handled properly."""
import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.execute_model = Mock(
side_effect=RAISED_ERROR(RAISED_VALUE))
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_forward) as engine:
client = await engine.make_client()
# Server should be healthy after initial probe.
await asyncio.sleep(2.0)
await client.check_health()
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
await asyncio.sleep(1.0)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Shutdown.
client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
with RemoteMQLLMEngine(
engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_model_executor_health) as engine:
client = await engine.make_client()
assert client.is_running
# Health probe should throw RAISED_ERROR.
await asyncio.sleep(15.)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during abort call.
engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# Firsh check health should work.
await client.check_health()
# Trigger an abort on the client side.
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")
# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())
# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
await abort_task
# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
"invalid-lora", 1,
"invalid-path")):
pass
# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
# Shutdown.
client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# When LLMEngine is loaded, it will crash.
def mock_init():
raise ValueError
monkeypatch.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# it should not crash, when cuda is initialized
# in the API server process
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000
# Scenarios to test for num generated token.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_load(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(client, request_id, NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
failed_request_id = None
tokens = None
for task in tasks:
num_generated_tokens, request_id = await task
if (num_generated_tokens != NUM_EXPECTED_TOKENS
and failed_request_id is None):
failed_request_id = request_id
tokens = num_generated_tokens
assert failed_request_id is None, (
f"{failed_request_id} generated {tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
# Shutdown.
client.close()
import asyncio
import multiprocessing
from typing import Callable, Tuple, Union
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
async def generate(
client: MQLLMEngineClient,
request_id: str,
num_tokens: int,
return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
final_output = None
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):
count += 1
final_output = out
await asyncio.sleep(0.)
if return_output:
return final_output
# Confirm we generated all the tokens we expected.
return count, request_id
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Run engine.
engine.start()
class RemoteMQLLMEngine:
def __init__(self,
engine_args: AsyncEngineArgs,
ipc_path: str,
run_fn: Callable = run_normal) -> None:
self.engine_args = engine_args
self.ipc_path = ipc_path
context = multiprocessing.get_context("spawn")
self.proc = context.Process(target=run_fn,
args=(engine_args, ipc_path))
self.proc.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.kill()
async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config)
while True:
try:
await client.setup()
break
except TimeoutError:
assert self.proc.is_alive()
return client
import os
from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s
os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",
......
......@@ -119,7 +119,7 @@ class RemoteOpenAIServer:
def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
try:
self.proc.wait(3)
self.proc.wait(8)
except subprocess.TimeoutExpired:
# force kill if needed
self.proc.kill()
......
......@@ -601,9 +601,12 @@ class AsyncLLMEngine:
return self._errored_with is not None
@property
def limit_concurrency(self) -> Optional[int]:
"""Maximum number of concurrently running requests."""
return None
def dead_error(self) -> BaseException:
return AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
def set_errored(self, exc: Exception) -> None:
self._errored_with = exc
......
......@@ -1289,6 +1289,7 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
......
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from typing import List, Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0
class MQEngineDeadError(RuntimeError):
pass
@dataclass
......@@ -27,24 +30,44 @@ class RPCGenerateRequest:
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCError:
request_id: Optional[str]
is_engine_errored: bool
exception: BaseException
@dataclass
class RPCAbortRequest:
request_id: str
class RPCUtilityRequest(Enum):
class RPCHealthRequest:
pass
class RPCStartupRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9
START_PROFILE = 10
STOP_PROFILE = 11
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
RPCUtilityRequest]
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
def ENGINE_DEAD_ERROR(
error: Optional[BaseException] = None) -> MQEngineDeadError:
if error is None:
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error")
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment