# test_zmq_client.py
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid

import pytest
import pytest_asyncio

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
                                                RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer


@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` endpoint inside a throwaway directory.

    The directory is created per test and removed (together with anything
    created inside it) when the test using this fixture finishes.
    """
    with tempfile.TemporaryDirectory() as socket_dir:
        endpoint = "ipc://" + socket_dir + "/" + str(uuid.uuid4())
        yield endpoint


@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
    """Run an ``AsyncEngineRPCServer`` with a mocked engine on *tmp_socket*.

    ``AsyncLLMEngine.from_engine_args`` is patched to return an ``AsyncMock``
    only while the server is being constructed (presumably the server builds
    its engine through that factory, so it ends up holding the mock). The
    server's request loop runs as a background task for the duration of the
    test and is stopped and cleaned up on teardown.
    """
    dummy_engine = unittest.mock.AsyncMock()

    def dummy_engine_builder(*args, **kwargs):
        return dummy_engine

    with monkeypatch.context() as m:
        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    try:
        yield server
    finally:
        # cancel() only *requests* cancellation; wait until the loop task has
        # actually finished before cleanup() tears down the server, instead
        # of cleaning up underneath a still-pending task.
        # asyncio.wait() does not raise the task's outcome, so a loop that
        # already stopped with an error does not turn into a teardown error.
        server_task.cancel()
        try:
            await asyncio.wait({server_task})
        finally:
            server.cleanup()


@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
    """Yield an RPC client connected to *tmp_socket*, closing it afterwards.

    The connection check happens inside the ``try`` so that the client is
    also closed when that check fails, rather than leaking on a failed setup.
    """
    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
    try:
        # Sanity check: the server is connected
        await client._wait_for_server_rpc()
        yield client
    finally:
        client.close()


@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
                                                client: AsyncEngineRPCClient):
    """A data request the server never answers must time out client-side.

    ``client.setup()`` asks the server for its config. With the server's
    ``get_config`` stubbed out, no reply comes back, so the client's own
    timeout (shrunk via ``_data_timeout``) has to fire. That timeout shows up
    as the matched ``TimeoutError``, not the outer ``wait_for`` deadline.
    """
    with monkeypatch.context() as patcher:
        # Stub the server so it never replies with a model config.
        patcher.setattr(dummy_server, "get_config", lambda x: None)
        patcher.setattr(client, "_data_timeout", 10)

        setup_task = asyncio.create_task(client.setup())
        with pytest.raises(TimeoutError, match="Server didn't reply within"):
            await asyncio.wait_for(setup_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
                                          client: AsyncEngineRPCClient):
    """An abort the server never acknowledges must neither hang nor raise.

    The server's ``abort`` handler is stubbed out, so no acknowledgement is
    sent. ``client.abort`` must still complete within the outer ``wait_for``
    deadline without raising.
    """
    with monkeypatch.context() as m:
        # Hang all abort requests
        m.setattr(dummy_server, "abort", lambda x: None)
        m.setattr(client, "_data_timeout", 10)

        # The client should suppress timeouts on `abort`s
        # and return normally, assuming the server will eventually
        # abort the request.
        client_task = asyncio.get_running_loop().create_task(
            client.abort("test request id"))
        await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
    """Server-side errors while serving data are re-raised by the client."""
    expected_error = RuntimeError("Client test exception")

    def _fail_model_config():
        raise expected_error

    with monkeypatch.context() as patcher:
        # Make the (mocked) engine blow up when asked for its model config.
        patcher.setattr(dummy_server.engine, "get_model_config",
                        _fail_model_config)
        patcher.setattr(client, "_data_timeout", 10)

        # client.setup() goes through the server's config path, which now
        # raises; the task must finish by raising that same error.
        setup_task = asyncio.create_task(client.setup())
        with pytest.raises(RuntimeError, match=str(expected_error)):
            await asyncio.wait_for(setup_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
                                           client: AsyncEngineRPCClient):
    """A closed client rejects real work but lets no-op calls through."""
    client.close()

    # Calls that need a live connection must fail with an explicit error.
    with pytest.raises(RPCClientClosedError):
        await client.check_health()

    with pytest.raises(RPCClientClosedError):
        outputs = client.generate(None, None, None)
        async for _output in outputs:
            pass

    # Best-effort no-ops such as aborting still return normally.
    await client.abort("test-request-id")
    await client.do_log_stats()