# test_async_llm_engine.py (11.3 KB) — source-view header residue: "Newer Older"
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
5
import os
import uuid
6
from asyncio import CancelledError
7
from copy import copy
8
from dataclasses import dataclass
9
from typing import Optional
10
11

import pytest
12
import pytest_asyncio
13
import torch
14

15
from vllm import SamplingParams
16
from vllm.config import ParallelConfig
17
from vllm.distributed import cleanup_dist_env_and_memory
18
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
19
from vllm.outputs import RequestOutput as RealRequestOutput
20
from vllm.sampling_params import RequestOutputKind
21
22

from ..utils import wait_for_gpu_memory_to_clear
23
24
25
26
27
28
29
30


@dataclass
class RequestOutput:
    """Minimal stand-in for a request output, produced by ``MockEngine``.

    Only the two attributes the mock needs are modelled: the id of the
    request that produced the output and whether that request has finished.
    """

    # The mock stores whatever id the test passed to ``generate`` (strings
    # such as ``"2"``), so this is annotated as ``str`` rather than ``int``.
    request_id: str
    finished: bool = False


31
32
33
34
35
@dataclass
class MockModelConfig:
    """Stub model config carrying only ``use_async_output_proc``.

    NOTE(review): presumably this is the only model-config attribute the
    mocked engine path consults — confirm before relying on it elsewhere.
    """

    # Intentionally left un-annotated: that makes it a plain class attribute
    # instead of a dataclass field, so ``MockModelConfig()`` takes no args.
    use_async_output_proc = True


36
37
38
39
40
41
42
class MockEngine:
    """In-process engine double used as ``MockAsyncLLMEngine``'s inner engine.

    It counts how often the engine API is exercised (steps, added requests,
    aborts) so tests can assert on the async background loop's behaviour,
    and it only "generates" while a test has called :meth:`generate`.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently "generating"; ``None`` means idle.
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig()
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):
        """Record one engine step.

        Returns a one-element list with an output for the active request, or
        an empty list when no request is generating.
        """
        # PP size is 1, ignore virtual engine
        del virtual_engine  # Unused
        self.step_calls += 1
        # Use the same ``is not None`` test as has_unfinished_requests() so
        # the mock never reports unfinished work while yielding no output.
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op stub."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """No-op stub."""
        pass

    def generate(self, request_id):
        """Start "generating" for ``request_id`` on subsequent steps."""
        self.request_id = request_id

    def stop_generating(self):
        """Return to the idle state."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Synchronous add: count the call and print the running total."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Asynchronous add: only counts the call."""
        del kwargs  # Unused
        self.add_request_calls += 1
        return

    def abort_request(self, request_id):
        """Count an abort; the request id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """True while a request is generating."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same as :meth:`has_unfinished_requests`; PP size is 1."""
        return self.request_id is not None

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` whose inner engine is replaced by ``MockEngine``.

    Tests reach the mock through ``engine.engine`` (e.g.
    ``engine.engine.step_calls``).
    """

    # NOTE(review): presumably AsyncLLMEngine instantiates ``_engine_class``
    # to build ``self.engine`` — confirm against the base class.
    _engine_class = MockEngine


@pytest.mark.asyncio
async def test_new_requests_event():
    """The background loop only steps the engine while there is work to do.

    Checks that adding a request wakes the idle loop for exactly one step,
    that a generating request keeps the loop stepping, and that the loop
    goes idle again once generation stops.
    """
    params = SamplingParams()

    engine = MockAsyncLLMEngine()
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been submitted yet, so the loop must not have stepped.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    # The mock is not generating, so the new request causes a single step.
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", params)
    engine.engine.generate("2")
    # Yield to the event loop a few times so the background loop can run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    # While a request is generating the loop keeps stepping.
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    # Once generation stops, no further steps happen.
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # Stop the first engine's background task before dropping the reference
    # so it is not left running (and destroyed while pending) at teardown.
    engine.shutdown_background_loop()

    engine = MockAsyncLLMEngine()
    # NOTE(review): elsewhere in this module ``get_scheduler_config()`` is
    # awaited, which suggests these getters are coroutine functions too. If
    # so, these asserts only check that a coroutine object was created (and
    # leave it un-awaited); awaiting them would require MockEngine to
    # implement the corresponding getters — confirm and tighten.
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None


132
def start_engine():
    """Build the real ``AsyncLLMEngine`` used by the module-scoped fixture.

    First calls ``wait_for_gpu_memory_to_clear`` on every visible CUDA device
    (threshold 2 GiB, 60 s timeout) — presumably blocking until memory left
    over from a previous engine is released. ``NUM_SCHEDULER_STEPS``
    (default ``"1"``) sets the engine's ``num_scheduler_steps``.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,  # 2 GiB
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        enforce_eager=True,
                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
    """Return a fresh random request id: a UUID4 in canonical string form."""
    fresh = uuid.uuid4()
    return f"{fresh}"


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped real engine shared by the GPU-backed tests below.

    ``VLLM_USE_V1`` is forced to ``"0"`` while the engine exists. Afterwards
    it is restored to its previous state: unset, empty string, or a value.
    The restore also runs if engine start-up fails.
    """
    # We cannot use monkeypatch since this is a module
    # scoped fixture and monkeypatch is function scoped.
    previous_value = os.getenv("VLLM_USE_V1", None)
    os.environ["VLLM_USE_V1"] = "0"
    try:
        # start_engine() is synchronous (GPU-memory wait + engine build), so
        # run it in the default executor rather than on the event loop.
        engine = await asyncio.get_running_loop().run_in_executor(
            None, start_engine)
        try:
            yield engine
        finally:
            engine.shutdown_background_loop()
            del engine
            await asyncio.sleep(0.1)
            cleanup_dist_env_and_memory()
    finally:
        # ``is not None`` so a pre-existing empty value is restored rather
        # than deleted.
        if previous_value is not None:
            os.environ["VLLM_USE_V1"] = previous_value
        else:
            os.environ.pop("VLLM_USE_V1", None)

@pytest.fixture
def should_do_global_cleanup_after_test(request) -> bool:
    """Report that no global cleanup should run after each test here.

    Returning ``False`` lets the tests in this module share the
    module-scoped ``async_engine`` fixture. (Presumably a conftest hook reads
    this fixture to decide on per-test cleanup — confirm.)
    """
    del request  # Unused
    return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Two concurrent greedy requests for the same prompt produce identical
    results.

    Also checks the number of streamed outputs: exactly 32 with single-step
    scheduling, and strictly between 1 and 32 with multi-step scheduling.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        """Stream ``prompt``; return ``(final_output, number_of_outputs)``."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        # Fail clearly here instead of with an AttributeError further down
        # if the stream yielded nothing or did not finish.
        assert final_output is not None
        assert final_output.finished
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        assert output_count == 32
    else:
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Stream ``prompt`` using output ``kind``.

        Returns ``(prompt_token_ids, token_ids, text, output_count)``, read
        from the last streamed output.
        """
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Stream ``prompt`` in DELTA mode and reassemble the full result.

        Returns the same tuple shape as ``run`` so the results compare
        directly.
        """
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: list[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                # Each later delta carries 1..num_scheduler_steps new tokens.
                assert 1 <= len(token_ids) <= num_scheduler_steps
                # An empty text delta is tolerated only when stop strings are
                # configured (presumably text can be held back while it might
                # still form a stop string).
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    prompt = "common input prompt"
    results = await asyncio.gather(
        run(prompt, RequestOutputKind.CUMULATIVE),
        run(prompt, RequestOutputKind.FINAL_ONLY),
        run_deltas(prompt))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    # FINAL_ONLY delivers exactly one message.
    assert final[3] == 1


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Aborting a request mid-stream ends its generator with CancelledError.

    The request is aborted after ``stop_at`` unfinished outputs. No further
    outputs may be delivered after the abort.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    # With multi-step scheduling each output presumably covers several
    # tokens, so abort right after the first one while it is still running.
    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    assert i == stop_at


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """A slow consumer still receives every output, in order, at the end.

    After the first output the consumer sleeps before reading the rest. It
    must then see exactly ``num_tokens`` outputs, with only the last one
    marked finished. This test is skipped under multi-step scheduling.
    """
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    # One output per token with single-step scheduling, so every count below
    # derives from this value.
    num_tokens = 10
    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=num_tokens,
        max_tokens=num_tokens,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < num_tokens - 1:
            assert not output.finished
        i += 1

    assert i == num_tokens
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == num_tokens
    assert final_output.finished