# test_async_llm_engine.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
6
import os
import uuid
7
from asyncio import CancelledError
8
from copy import copy
9
from dataclasses import dataclass
10
from typing import Optional
11
12

import pytest
13
import pytest_asyncio
14
import torch
15

16
from vllm import SamplingParams
17
from vllm.config import ParallelConfig
18
from vllm.distributed import cleanup_dist_env_and_memory
19
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
20
from vllm.outputs import RequestOutput as RealRequestOutput
21
from vllm.sampling_params import RequestOutputKind
22
23

from ..utils import wait_for_gpu_memory_to_clear
24
25
26
27
28
29
30
31


@dataclass
class RequestOutput:
    """Lightweight fake of vLLM's ``RequestOutput`` produced by ``MockEngine``.

    Only the two attributes the mock needs are modelled.
    """

    # Id of the request this output belongs to. The tests pass string ids
    # ("1", "2", "3"), so the annotation is ``str`` rather than ``int``.
    request_id: str
    # Completion flag; MockEngine never sets it, so it stays False.
    finished: bool = False


32
33
34
35
36
@dataclass
class MockModelConfig:
    """Stand-in for the engine's model config, attached to ``MockEngine``.

    ``use_async_output_proc`` carries no type annotation, so it is a plain
    class attribute shared by all instances, not a dataclass field.
    """

    use_async_output_proc = True


37
38
39
40
41
42
43
class MockEngine:
    """Minimal stand-in for the engine wrapped by ``MockAsyncLLMEngine``.

    It records how often the async wrapper calls into it and, while a
    request id is set via :meth:`generate`, returns one fake
    ``RequestOutput`` per step. :meth:`stop_generating` clears the id again.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently "generating", or None when idle.
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig()
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):
        """Count one engine step and return the fake outputs for it."""
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        # Use the same ``is not None`` test as has_unfinished_requests so the
        # two can never disagree about whether a request is in flight.
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        pass

    async def stop_remote_worker_execution_loop_async(self):
        pass

    def generate(self, request_id):
        """Start reporting ``request_id`` as unfinished on every step."""
        self.request_id = request_id

    def stop_generating(self):
        """Stop reporting any unfinished request."""
        self.request_id = None

    def add_request(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        # PP size is 1, so the virtual engine index is irrelevant.
        return self.request_id is not None

85
86

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` whose wrapped engine type is ``MockEngine``.

    The tests reach the mock through ``engine.engine`` to read its counters.
    """

    _engine_class = MockEngine
88
89
90
91


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check that the background loop only steps the engine when there is work.

    Drives a ``MockAsyncLLMEngine`` and counts ``step_async`` /
    ``add_request_async`` calls on the wrapped ``MockEngine``. Requests are
    added while the mock is told to start and stop "generating".
    """
    params = SamplingParams()

    engine = MockAsyncLLMEngine()
    engine.start_background_loop()
    try:
        # Idle: no requests yet, so the loop must not step.
        await asyncio.sleep(0.01)
        assert engine.engine.step_calls == 0

        # A single new request triggers exactly one step, because the mock
        # reports no unfinished request afterwards.
        await engine.add_request("1", "", params)
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 1
        assert engine.engine.step_calls == 1

        # While the mock is "generating", the loop keeps stepping.
        await engine.add_request("2", "", params)
        engine.engine.generate("2")
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        await asyncio.sleep(0)
        assert engine.engine.add_request_calls == 2
        assert engine.engine.step_calls >= 2
        await asyncio.sleep(0.001)
        assert engine.engine.step_calls >= 3

        # Once generation stops, the step count must settle.
        engine.engine.stop_generating()
        await asyncio.sleep(0.001)
        old_step_calls = engine.engine.step_calls
        await asyncio.sleep(0.001)
        assert engine.engine.step_calls == old_step_calls

        # One more request causes exactly one more step, then the loop idles.
        await engine.add_request("3", "", params)
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 3
        assert engine.engine.step_calls == old_step_calls + 1
        await asyncio.sleep(0.01)
        assert engine.engine.add_request_calls == 3
        assert engine.engine.step_calls == old_step_calls + 1
    finally:
        # Do not leave the background loop running after the test,
        # even when an assertion above fails.
        engine.shutdown_background_loop()

    fresh_engine = MockAsyncLLMEngine()
    # NOTE(review): the scheduler-config getter is awaited elsewhere in this
    # file, which suggests these AsyncLLMEngine getters are coroutine
    # functions. If so, these asserts only check that a coroutine object was
    # created, and they trigger "never awaited" warnings. Awaiting them would
    # require MockEngine to implement the matching getters — confirm.
    assert fresh_engine.get_model_config() is not None
    assert fresh_engine.get_tokenizer() is not None
    assert fresh_engine.get_decoding_config() is not None
131
132


133
def start_engine():
    """Build a real ``AsyncLLMEngine`` serving ``facebook/opt-125m``.

    First calls ``wait_for_gpu_memory_to_clear`` on every visible CUDA device
    (2 GiB threshold, 60 s timeout). That helper presumably guards against
    memory left over by earlier tests.

    The number of scheduler steps comes from the ``NUM_SCHEDULER_STEPS``
    environment variable and defaults to 1.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        enforce_eager=True,
                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
    """Return a new random (version 4) UUID as a string, for unique request ids."""
    fresh = uuid.uuid4()
    return f"{fresh}"
151
152
153
154


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped real ``AsyncLLMEngine``, built with ``VLLM_USE_V1=0``.

    The previous ``VLLM_USE_V1`` value is restored on teardown, and also when
    starting the engine fails. A previously empty value is restored as well.
    """
    # We cannot use monkeypatch since this is a module
    # scoped fixture and monkeypatch is function scoped.
    previous_value = os.getenv("VLLM_USE_V1", None)
    os.environ["VLLM_USE_V1"] = "0"
    try:
        # start_engine blocks (GPU wait, model load), so run it in the
        # default executor instead of on the event loop thread.
        engine = await asyncio.get_running_loop().run_in_executor(
            executor=None, func=start_engine)
        try:
            yield engine
        finally:
            engine.shutdown_background_loop()
            del engine
            await asyncio.sleep(0.1)
            cleanup_dist_env_and_memory()
    finally:
        # Compare with None so that an originally empty string is put back
        # instead of the variable being removed.
        if previous_value is not None:
            os.environ["VLLM_USE_V1"] = previous_value
        else:
            os.environ.pop("VLLM_USE_V1", None)

174
175
176
177
178
179
180
181

@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Disable the global cleanup step after each test in this module.

    Returning ``False`` lets these tests share the module-scoped
    ``async_engine`` fixture. This presumably overrides a conftest fixture of
    the same name — confirm against conftest.
    """
    del request  # only part of the fixture signature; not needed here
    share_engine_between_tests = False
    return share_engine_between_tests


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Two identical greedy requests run concurrently must give identical
    results, and the number of streamed outputs must match the scheduler
    step configuration."""
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        """Stream one request to completion.

        Returns ``(last output, number of outputs)``.
        """
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output
        # Fail clearly here rather than with an AttributeError further down.
        assert final_output is not None
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        assert output_count == 32
    else:
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Stream one request using ``kind``.

        Returns ``(prompt ids, output ids, text, number of outputs)`` taken
        from the last output.
        """
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Stream one DELTA request and reassemble the increments.

        Returns the same tuple shape as ``run``.
        """
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: list[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                # Each later delta must carry 1..num_scheduler_steps tokens.
                assert 1 <= len(token_ids) <= num_scheduler_steps
                # Text may be empty only when a stop string is configured
                # (presumably withheld while the stop string could match).
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    results = await asyncio.gather(
        run("common input prompt", RequestOutputKind.CUMULATIVE),
        run("common input prompt", RequestOutputKind.FINAL_ONLY),
        run_deltas("common input prompt"))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    assert final[3] == 1
323
324
325


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Aborting a request mid-stream must end its generator with
    ``CancelledError`` after exactly ``stop_at`` unfinished outputs."""
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    # With multi-step scheduling, each output presumably carries several
    # tokens, so abort after the first output while the request is still
    # unfinished.
    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    assert i == stop_at
353
354
355


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """A consumer that stalls after the first output must still receive all
    10 outputs, with only the last one marked finished."""
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished