# test_async_llm_engine.py
1
import asyncio
2
3
import os
import uuid
4
from asyncio import CancelledError
5
from copy import copy
6
from dataclasses import dataclass
7
from typing import List, Optional
8
9

import pytest
10
import pytest_asyncio
11
import torch
12

13
from vllm import SamplingParams
14
from vllm.config import ParallelConfig
15
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
16
from vllm.outputs import RequestOutput as RealRequestOutput
17
from vllm.sampling_params import RequestOutputKind
18

19
from ..conftest import cleanup
20
from ..utils import wait_for_gpu_memory_to_clear
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35


@dataclass
class RequestOutput:
    """Minimal stand-in for ``vllm.outputs.RequestOutput`` used by MockEngine.

    Only the attributes that ``MockEngine.step_async`` produces are modelled.
    """

    # Id of the request this output belongs to. MockEngine forwards the id
    # given to ``generate()`` (string ids such as "2"), hence ``str``.
    request_id: str
    # Whether generation for this request has completed.
    finished: bool = False


class MockEngine:
    """Fake engine that counts the calls made on it.

    ``generate()`` / ``stop_generating()`` toggle whether ``step_async``
    returns an output, which lets a test drive the async background loop
    by hand and then inspect the counters.
    """

    def __init__(self):
        # Per-method call counters inspected by the tests.
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently "generating"; None means idle.
        self.request_id = None

        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        """Count one step; return a single output while a request is active."""
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    async def process_model_inputs_async(self, *args, **kwargs):
        pass

    async def stop_remote_worker_execution_loop_async(self):
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as active so ``step_async`` yields outputs."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request; ``step_async`` then returns ``[]``."""
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        # PP size is 1, so every virtual engine sees the same single request.
        del virtual_engine  # Unused
        return self.request_id is not None

76
77

class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine wired to MockEngine via ``_engine_class``.

    Tests reach the mock through ``engine.engine`` to read its counters.
    """

    _engine_class = MockEngine
79
80
81
82


@pytest.mark.asyncio
async def test_new_requests_event():
    """The background loop steps only when there is work to do.

    Drives MockEngine by hand: a step should happen when a request is added
    or while one is generating, and none while the engine is idle.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False)
    engine.start_background_loop()
    try:
        await asyncio.sleep(0.01)
        assert engine.engine.step_calls == 0

        await engine.add_request("1", "", None)
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 1
        assert engine.engine.step_calls == 1

        await engine.add_request("2", "", None)
        engine.engine.generate("2")
        # Yield to the event loop a few times so the background loop runs.
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        assert engine.engine.add_request_calls == 2
        assert engine.engine.step_calls >= 2
        await asyncio.sleep(0.001)
        assert engine.engine.step_calls >= 3
        engine.engine.stop_generating()
        await asyncio.sleep(0.001)
        old_step_calls = engine.engine.step_calls
        await asyncio.sleep(0.001)
        # Once generation stops, the loop must go idle.
        assert engine.engine.step_calls == old_step_calls

        await engine.add_request("3", "", None)
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 3
        assert engine.engine.step_calls == old_step_calls + 1
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 3
        assert engine.engine.step_calls == old_step_calls + 1
    finally:
        # Cancel the background loop so it does not outlive this test,
        # including when one of the assertions above fails.
        engine.shutdown_background_loop()

    engine = MockAsyncLLMEngine(worker_use_ray=True)
    # NOTE(review): other tests in this file `await` engine config getters
    # (e.g. get_scheduler_config). If these getters are coroutines too, the
    # checks below only confirm a coroutine object was created — confirm.
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
120
121


122
def start_engine():
    """Build the real ``AsyncLLMEngine`` shared by the module fixture.

    Waits for GPU memory to drain, then reads the scheduler step count from
    the ``NUM_SCHEDULER_STEPS`` environment variable (default ``"1"``).

    Returns:
        An ``AsyncLLMEngine`` serving ``facebook/opt-125m`` with eager mode.
    """
    # Presumably blocks until each visible CUDA device is below
    # threshold_bytes (2 GiB), giving up after 60 s — confirm in tests/utils.
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        enforce_eager=True,
                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
    """Return a fresh random request id as a canonical UUID4 string."""
    fresh = uuid.uuid4()
    return f"{fresh}"
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped real engine shared by the GPU tests in this file.

    ``start_engine`` is blocking (it waits on GPU memory and builds the
    engine), so it runs in the default executor to keep the event loop free.
    """
    # get_running_loop() is the documented way to reach the loop from inside
    # a coroutine; get_event_loop() is discouraged there.
    loop = asyncio.get_running_loop()
    engine = await loop.run_in_executor(None, start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()
        del engine
        # Presumably lets the cancelled background loop wind down before
        # the global cleanup() runs.
        await asyncio.sleep(0.1)
        cleanup()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Disable per-test global cleanup for this module.

    The ``async_engine`` fixture is module-scoped and shared between these
    tests; presumably this overrides a conftest fixture of the same name.
    """
    del request  # Unused; present only to match the fixture signature.
    return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Two identical concurrent greedy requests produce identical outputs."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        # min_tokens == max_tokens pins the generation length to 32 tokens;
        # temperature=0 selects greedy decoding so both runs should agree.
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        # Fail clearly if the stream produced nothing.
        assert final_output is not None
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        # One streamed output per generated token.
        assert output_count == 32
    else:
        # Multi-step scheduling can emit several tokens per output.
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Stream with ``kind``; return the final prompt ids, token ids,
        text and the number of outputs received."""
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Stream DELTA outputs and reassemble them into the same tuple
        shape that ``run`` returns."""
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: List[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                assert 1 <= len(token_ids) <= num_scheduler_steps
                # With a stop string, a delta may carry tokens but no text
                # (presumably text that could start the stop string is
                # withheld).
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    results = await asyncio.gather(
        run("common input prompt", RequestOutputKind.CUMULATIVE),
        run("common input prompt", RequestOutputKind.FINAL_ONLY),
        run_deltas("common input prompt"))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    # FINAL_ONLY emits exactly one output.
    assert final[3] == 1
303
304
305


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Aborting a request mid-stream ends its generator with CancelledError."""
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    # Presumably abort after the first output under multi-step scheduling
    # because each output may carry several tokens and the 13-token request
    # could otherwise finish before output #5.
    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    # No further outputs may arrive after the abort.
    assert i == stop_at
333
334
335


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """Outputs produced while the consumer is slow are all still delivered."""
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished