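# test_async_llm_engine.py (6.37 KB)
# NOTE: captured from a web diff view ("Newer Older"); bare numbers between
# statements are line-number residue from that view.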
1
import asyncio
2
import os
3
from asyncio import CancelledError
4
from dataclasses import dataclass
5
from typing import Optional
6
7

import pytest
8
import pytest_asyncio
9
import torch
10

11
from vllm import SamplingParams
12
from vllm.config import ParallelConfig
13
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
14
from vllm.outputs import RequestOutput as RealRequestOutput
15

16
from ..conftest import cleanup
17
from ..utils import wait_for_gpu_memory_to_clear
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


@dataclass
class RequestOutput:
    """Minimal stand-in for ``vllm.outputs.RequestOutput`` used by MockEngine.

    Only ``request_id`` and ``finished`` are modelled. Request ids are
    strings throughout this module (e.g. ``"1"``, ``"2"``), matching the real
    ``RequestOutput``.
    """
    request_id: str
    finished: bool = False


class MockEngine:
    """Stand-in for the inner LLM engine returned by ``MockAsyncLLMEngine``.

    Each entry point only bumps a counter (``step_calls``,
    ``add_request_calls``, ``abort_request_calls``), so tests can check how
    the async background loop drives the engine. ``generate`` and
    ``stop_generating`` toggle whether the mock reports an unfinished request.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently "generating", or None when idle.
        self.request_id = None

        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        """Count the step and emit one output while a request is active."""
        # PP size is 1, ignore virtual engine
        del virtual_engine  # Unused
        self.step_calls += 1
        # Use the same ``is not None`` test as has_unfinished_requests so the
        # two never disagree (e.g. for an empty-string request id).
        if self.request_id is not None:
            return [RequestOutput(request_id=self.request_id)]
        return []

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op: input processing is not modelled by the mock."""

    async def stop_remote_worker_execution_loop_async(self):
        """No-op: the mock has no remote workers."""

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently generating."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so the mock reports no pending work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add_request call; arguments are ignored."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Count an async add_request call; arguments are ignored."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort; the request id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request is marked as generating."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same answer as ``has_unfinished_requests``; PP size is 1."""
        del virtual_engine  # Unused
        return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine whose inner engine is always a fresh ``MockEngine``."""

    def _init_engine(self, *args, **kwargs):
        # Construction arguments are ignored; the mock needs none of them.
        mock_engine = MockEngine()
        return mock_engine


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check when the background loop steps the mock engine.

    Drives a ``MockAsyncLLMEngine`` and inspects the ``MockEngine`` call
    counters after each request. Then it checks that the config/tokenizer
    accessors return non-None values when the deprecated ``engine_use_ray``
    mode is enabled.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No requests yet: the loop must not have stepped.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    # Yield to the event loop a few times so the background loop can run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    # With no active request, the step count must stay flat.
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # Allow deprecated engine_use_ray to not raise exception. The previous
    # value is restored in ``finally`` so that a failing assertion below does
    # not leak the setting into other tests or clobber a pre-existing value.
    env_key = "VLLM_ALLOW_ENGINE_USE_RAY"
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "1"
    try:
        engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
        assert engine.get_model_config() is not None
        assert engine.get_tokenizer() is not None
        assert engine.get_decoding_config() is not None
    finally:
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value


def start_engine():
    """Build the real AsyncLLMEngine used by the ``async_engine`` fixture.

    First calls ``wait_for_gpu_memory_to_clear`` for every visible CUDA
    device with a 2 GiB threshold and a 60 s timeout. Presumably it waits
    until used memory on each device drops below ``threshold_bytes``; confirm
    against the helper in ``..utils``. Then it constructs the engine for
    ``facebook/opt-125m`` with ``enforce_eager=True``.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped AsyncLLMEngine shared by the tests in this module.

    ``start_engine`` is synchronous (it waits on GPU memory and then builds
    the engine), so it runs in the default executor rather than on the
    event loop. On teardown the background loop is shut down, the reference
    dropped, and ``cleanup()`` called.
    """
    # Inside a coroutine, get_running_loop() is the recommended way to reach
    # the current loop (get_event_loop() has policy-dependent behaviour).
    loop = asyncio.get_running_loop()
    engine = await loop.run_in_executor(None, start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()
        del engine
        # Brief pause before cleanup, presumably so cancelled background
        # tasks can finish unwinding.
        await asyncio.sleep(0.1)
        cleanup()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Switch off the per-test global cleanup for this module.

    The tests here share the module-scoped ``async_engine`` fixture, so it
    must survive between tests. Presumably this overrides a same-named
    fixture from the shared conftest; that fixture is not visible here.
    """
    cleanup_enabled = False
    return cleanup_enabled


@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
    """Run two generation requests concurrently on the shared engine.

    Each prompt doubles as its request id. Both streams must produce a final
    output, and that output must be marked finished.
    """

    async def run(prompt: str) -> RealRequestOutput:
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
        )

        # Keep only the last streamed output. Initialise it so an empty
        # stream fails the assertion below instead of raising
        # UnboundLocalError.
        final_output: Optional[RealRequestOutput] = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=prompt):
            final_output = output
        assert final_output is not None
        return final_output

    results = await asyncio.gather(
        run("test0"),
        run("test1"),
    )
    assert len(results) == 2
    assert all(result.finished for result in results)


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
    """Aborting a request mid-stream surfaces as ``CancelledError``.

    The request is aborted after its fifth (unfinished) output. Continuing
    to consume the stream must then raise ``CancelledError``.
    """
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
    )

    abort_after = 5
    received = 0
    with pytest.raises(CancelledError):
        stream = async_engine.generate("test2",
                                       sampling_params,
                                       request_id="test2")
        async for output in stream:
            assert not output.finished
            received += 1
            if received == abort_after:
                await async_engine.abort("test2")

    assert received == abort_after


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
    """A consumer that stalls mid-stream still receives every output.

    After the first output the consumer sleeps for one second before
    draining the rest. All ten outputs must still arrive, the first nine
    unfinished, and the last one finished with ten token ids.
    """
    num_tokens = 10
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=num_tokens,
        max_tokens=num_tokens,
    )

    stream = async_engine.generate("test3",
                                   sampling_params,
                                   request_id="test3")
    seen = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if seen == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if seen < num_tokens - 1:
            assert not output.finished
        seen += 1

    assert seen == num_tokens
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == num_tokens
    assert final_output.finished