# test_async_llm_engine.py
1
import asyncio
2
from asyncio import CancelledError
3
from dataclasses import dataclass
4
from typing import Optional
5
6

import pytest
7
import pytest_asyncio
8
import torch
9

10
from vllm import SamplingParams
11
from vllm.config import ParallelConfig
12
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
13
from vllm.outputs import RequestOutput as RealRequestOutput
14

15
from ..conftest import cleanup
16
from ..utils import wait_for_gpu_memory_to_clear
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


@dataclass
class RequestOutput:
    """Lightweight stand-in for a request output, built by the mock engine."""
    # Id of the request this output belongs to. The ids used in this file
    # are strings (e.g. "2"), so the field is annotated as ``str``.
    request_id: str
    # Completion flag; defaults to "not finished".
    finished: bool = False


class MockEngine:
    """Test double for the engine driven by ``AsyncLLMEngine``.

    It performs no generation. It only counts how often its entry points
    are called and tracks a single "active" request id that the test toggles
    through ``generate`` / ``stop_generating``.
    """

    def __init__(self):
        # Call counters inspected by the tests.
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently marked as generating; None means idle.
        self.request_id = None

        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        """Count one step and return an output for the active request, if any.

        Returns a one-element list while a request id is set, else ``[]``.
        """
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        # Use the same "is not None" test as has_unfinished_requests so the
        # two never disagree for falsy ids such as "" or 0.
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op placeholder; arguments are ignored."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """No-op placeholder."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently generating."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so the engine reports no pending work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count an added request; the request details are discarded."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Async counterpart of ``add_request``; only bumps the counter."""
        del kwargs  # Unused
        self.add_request_calls += 1
        return

    def abort_request(self, request_id):
        """Count an abort; the id is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request id is set."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same as ``has_unfinished_requests``; ``virtual_engine`` is unused."""
        return self.request_id is not None

72
73

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass whose ``_engine_class`` is ``MockEngine``.

    The tests in this file read the mock's call counters through
    ``engine.engine``.
    """
    _engine_class = MockEngine
75
76
77
78


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check when the background loop steps the (mock) engine.

    The mock engine only counts calls, so the assertions below track how
    ``step_calls`` and ``add_request_calls`` evolve as requests are added and
    as the mock is switched between "generating" and idle.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # No request has been added yet, so no step is expected.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    # The mock is not "generating", so exactly one step is expected.
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    # Yield to the event loop a few times so the background loop can run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    # While the mock reports unfinished work, stepping continues.
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    # Once idle, the step counter must stop moving.
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    # Still idle: nothing changes after waiting again.
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # A second engine built with worker_use_ray=True must still expose its
    # configuration accessors.
    engine = MockAsyncLLMEngine(worker_use_ray=True)
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
116
117


118
def start_engine():
    """Create the ``AsyncLLMEngine`` used by the ``async_engine`` fixture.

    First calls ``wait_for_gpu_memory_to_clear`` for every CUDA device that
    torch reports (threshold 2 GiB, timeout 60 s), presumably so that memory
    held by earlier tests is released before the model loads. It then builds
    the engine for ``facebook/opt-125m`` with ``enforce_eager=True``.

    Returns:
        The engine returned by ``AsyncLLMEngine.from_engine_args``.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped fixture yielding the engine built by ``start_engine``.

    On teardown the engine's background loop is shut down, the local
    reference is dropped, and ``cleanup()`` runs after a short pause.
    """
    loop = asyncio.get_event_loop()
    # start_engine is synchronous, so run it in the loop's default executor
    # (executor=None) instead of calling it directly in this coroutine.
    engine = await loop.run_in_executor(None, start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()
        del engine
        await asyncio.sleep(0.1)
        cleanup()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Fixture that always returns False.

    Disabling the global cleanup lets the tests in this file share the
    ``async_engine`` fixture.
    """
    do_global_cleanup = False
    return do_global_cleanup


@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
    """Run two generations concurrently and check that both produce output."""

    async def run(prompt: str):
        """Consume the stream for ``prompt`` and return its last output."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
        )

        # Start from None so that an empty stream is reported by the
        # assertion at the end of the test instead of surfacing here as an
        # UnboundLocalError.
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=prompt):
            final_output = output
        return final_output

    results = await asyncio.gather(
        run("test0"),
        run("test1"),
    )
    assert len(results) == 2
    # Every stream must have yielded at least one output.
    assert all(result is not None for result in results)
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
    """Abort a request after five outputs and expect ``CancelledError``.

    The request is pinned to exactly ten tokens (min_tokens == max_tokens),
    so none of the first five outputs may be marked finished.
    """
    params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)
    request_id = "test2"

    received = 0
    with pytest.raises(CancelledError):
        stream = async_engine.generate("test2", params, request_id=request_id)
        async for output in stream:
            assert not output.finished
            received += 1
            if received == 5:
                await async_engine.abort(request_id)

    # No further output may be delivered after the abort.
    assert received == 5


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
    """Consume a stream slowly and check that no output is lost.

    After the first output the consumer sleeps for a second, then reads the
    rest. All ten outputs must still arrive, and only the last may be
    finished.
    """
    params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)
    stream = async_engine.generate("test3", params, request_id="test3")

    seen = 0
    last: Optional[RealRequestOutput] = None
    async for output in stream:
        last = output
        if seen == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if seen < 9:
            assert not output.finished
        seen += 1

    assert seen == 10
    assert last is not None
    assert len(last.outputs[0].token_ids) == 10
    assert last.finished