# test_async_llm_engine.py
1
import asyncio
2
3
import os
import uuid
4
from asyncio import CancelledError
5
from copy import copy
6
from dataclasses import dataclass
7
from typing import List, Optional
8
9

import pytest
10
import pytest_asyncio
11
import torch
12

13
from vllm import SamplingParams
14
from vllm.config import ParallelConfig
15
from vllm.distributed import cleanup_dist_env_and_memory
16
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
17
from vllm.outputs import RequestOutput as RealRequestOutput
18
from vllm.sampling_params import RequestOutputKind
19
20

from ..utils import wait_for_gpu_memory_to_clear
21
22
import os
from ..utils import models_path_prefix
23
24
25
26
27
28
29

@dataclass
class RequestOutput:
    """Lightweight stand-in for an engine output, returned by the mock engine."""
    # Identifier of the request this output belongs to.
    # NOTE(review): annotated as int, but the mock-engine test in this file
    # passes string ids such as "2" -- confirm the intended type.
    request_id: int
    # Completion flag; a freshly built output is not finished by default.
    finished: bool = False


30
31
32
33
34
@dataclass
class MockModelConfig:
    """Stub model config that only exposes ``use_async_output_proc``."""
    # Unannotated, so this is a plain class attribute shared by all instances
    # rather than a dataclass field.
    use_async_output_proc = True


35
36
37
38
39
40
41
class MockEngine:
    """Fake engine that only counts how often it is driven.

    The counters (``step_calls``, ``add_request_calls``,
    ``abort_request_calls``) let tests observe what the async wrapper did.
    ``request_id`` is the single request currently "generating"; while it
    is set the engine reports unfinished work and each step yields one
    ``RequestOutput`` for it.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None

        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):
        """Record one step and return the outputs for the active request.

        Returns an empty list when no request id is set (or it is falsy),
        otherwise a one-element list holding a ``RequestOutput``.
        """
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        if not self.request_id:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """Do nothing; present so callers can await it."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently being generated."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so the engine reports no pending work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add-request call and print the running total."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Count an asynchronous add-request call; the arguments are ignored."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort call; the request id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request id is set."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Return True while a request id is set; ``virtual_engine`` is unused."""
        return self.request_id is not None

83
84

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass whose engine class is ``MockEngine``."""
    _engine_class = MockEngine
86
87
88
89


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check that the background loop only steps the engine when there is work.

    Uses ``MockAsyncLLMEngine`` and inspects the mock's call counters after
    short sleeps that give the background loop a chance to run.
    """
    params = SamplingParams()

    engine = MockAsyncLLMEngine()
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been submitted yet, so the engine must not have stepped.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", params)
    # Make the mock report unfinished work so the loop keeps stepping.
    engine.engine.generate("2")
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    # Once generation stops, the step counter must stop moving.
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    # A second wait must not trigger any further step.
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    engine = MockAsyncLLMEngine()
    # NOTE(review): these getters are not awaited, unlike
    # ``get_scheduler_config`` in the other tests of this file. If they are
    # coroutines, these assertions only check the coroutine object -- confirm.
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
129
130


131
def start_engine():
    """Build the real ``AsyncLLMEngine`` shared by the GPU-backed tests.

    Waits (up to 60 s) for GPU memory usage to drop below 2 GiB on every
    visible CUDA device, then creates an engine for ``facebook/opt-125m``
    located under ``models_path_prefix``. The number of scheduler steps is
    read from the ``NUM_SCHEDULER_STEPS`` environment variable (default 1).
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    engine_args = AsyncEngineArgs(
        model=os.path.join(models_path_prefix, "facebook/opt-125m"),
        enforce_eager=True,
        num_scheduler_steps=num_scheduler_steps,
    )
    return AsyncLLMEngine.from_engine_args(engine_args)


def uid() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
149
150
151
152
153
154
155
156
157
158
159
160


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped fixture yielding the engine built by ``start_engine``.

    The engine is constructed in the default executor so the blocking
    start-up does not run on the event loop thread. On teardown the
    background loop is shut down, the local reference is dropped and
    distributed state / memory is cleaned up.
    """
    # Inside a coroutine the running loop is the one we want; this avoids
    # the policy-dependent behaviour of ``asyncio.get_event_loop()``.
    loop = asyncio.get_running_loop()
    engine = await loop.run_in_executor(None, start_engine)
    try:
        yield engine
    finally:
        engine.shutdown_background_loop()
        del engine
        await asyncio.sleep(0.1)
        cleanup_dist_env_and_memory()
162
163
164
165
166
167
168
169
170


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Fixture that always returns False; ``request`` is not used.

    NOTE(review): presumably overrides a same-named fixture defined
    elsewhere (e.g. a conftest) that controls per-test cleanup -- confirm.
    """
    # So we can share the async engine fixture between these tests
    return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Run the same prompt twice concurrently and compare the final outputs.

    Both requests use temperature 0 and exactly 32 tokens, so after
    clearing the per-request fields (metrics, request id) their string
    representations must match. The number of streamed outputs is 32 with
    a single scheduler step and strictly between 1 and 32 otherwise.
    """

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        assert output_count == 32
    else:
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Stream one request with the given kind; return its final state."""
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Stream one DELTA request and reassemble tokens and text."""
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: List[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                assert 1 <= len(token_ids) <= num_scheduler_steps
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    results = await asyncio.gather(
        run("common input prompt", RequestOutputKind.CUMULATIVE),
        run("common input prompt", RequestOutputKind.FINAL_ONLY),
        run_deltas("common input prompt"))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    assert final[3] == 1
312
313
314


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Abort a request mid-stream and expect ``CancelledError`` from the stream.

    The request is aborted after ``stop_at`` outputs (5 with a single
    scheduler step, otherwise 1); no further output may be observed and
    none of the received outputs may be marked finished.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    assert i == stop_at
342
343
344


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """Consume a stream slowly and check that no output is lost.

    After the first output the consumer sleeps for one second so the
    remaining outputs accumulate; all 10 outputs must still be delivered
    in order, with only the last one marked finished. Skipped when the
    engine runs with more than one scheduler step.
    """
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished