test_async_llm_engine.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
6
import os
import uuid
7
from asyncio import CancelledError
8
from copy import copy
9
10
from dataclasses import dataclass, field
from typing import Any, Optional
11
12

import pytest
13
import pytest_asyncio
14
import torch
15

16
from vllm import SamplingParams
17
from vllm.config import ParallelConfig
18
from vllm.distributed import cleanup_dist_env_and_memory
19
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
20
from vllm.outputs import RequestOutput as RealRequestOutput
21
from vllm.sampling_params import RequestOutputKind
22
23

from ..utils import wait_for_gpu_memory_to_clear
24
25
26
27
28
29
30
31


@dataclass
class RequestOutput:
    """Minimal stand-in for an engine output record used by the mock engine.

    Carries only the id of the request it belongs to and a completion flag.
    """
    # Identifier of the request this output belongs to.
    request_id: int
    # Whether the request is complete; outputs are unfinished by default.
    finished: bool = field(default=False)


32
33
34
@dataclass
class MockModelConfig:
    """Lightweight model-config stand-in handed to the mock engine."""
    # Unannotated on purpose: this stays a plain class attribute and is
    # therefore neither a dataclass field nor an ``__init__`` parameter.
    use_async_output_proc = True
    # Fresh dict per instance so instances never share mutable state.
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
36
37


38
39
40
41
42
43
44
class MockEngine:
    """Call-counting fake engine used to exercise the async engine loop.

    It performs no real work: it only records how many times ``step_async``,
    ``add_request``/``add_request_async`` and ``abort_request`` were invoked,
    and pretends a request is in flight while ``request_id`` is set.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # ``None`` means "no request is currently being generated".
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig()
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):
        """Count one step and return the output for the active request.

        Returns a single-element list holding an unfinished ``RequestOutput``
        while a request id is set, otherwise an empty list.
        """
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        # Use the same "is not None" test as has_unfinished_requests() so the
        # two never disagree for falsy ids such as 0 or "".
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op placeholder; accepts and ignores any arguments."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """No-op placeholder."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently being generated."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so the engine reports no pending work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add-request call and print the running total."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Count an asynchronous add-request call."""
        self.add_request_calls += 1
        return

    def abort_request(self, request_id):
        """Count an abort call; the id itself is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request id is set."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same as ``has_unfinished_requests``; ``virtual_engine`` is ignored."""
        return self.request_id is not None

86
87

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass whose engine class is ``MockEngine``."""
    # NOTE(review): presumably the base class instantiates ``_engine_class``
    # to build ``self.engine`` — confirm against AsyncLLMEngine.
    _engine_class = MockEngine
89
90
91
92


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check that the background loop steps only when there is work to do.

    Uses the mock engine's call counters to verify that the loop stays idle
    until a request is added, keeps stepping while a request is being
    generated, and goes idle again once generation stops.
    """
    params = SamplingParams()

    engine = MockAsyncLLMEngine()
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been submitted yet, so the loop must not have stepped.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", params)
    await asyncio.sleep(0.01)
    # One request added -> exactly one add call and one step.
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", params)
    # Make the mock report an in-flight request so the loop keeps stepping.
    engine.engine.generate("2")
    # Yield control a few times so the background loop gets to run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    # With no active request the step counter must stop advancing.
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    # Waiting longer must not trigger any further add or step calls.
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # A freshly built engine must expose its configuration accessors.
    engine = MockAsyncLLMEngine()
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
132
133


134
def start_engine():
    """Build a real ``AsyncLLMEngine`` for ``facebook/opt-125m``.

    Waits for GPU memory on every visible CUDA device to clear first, then
    reads the number of scheduler steps from the ``NUM_SCHEDULER_STEPS``
    environment variable (default ``1``).

    Returns:
        The engine created by ``AsyncLLMEngine.from_engine_args``.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,  # 2 GiB
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        enforce_eager=True,
                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
    """Return a freshly generated random UUID4 in its string form."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
152
153
154
155


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped fixture yielding a real engine with ``VLLM_USE_V1=0``.

    The engine is built in the default executor via ``start_engine``. On
    teardown the background loop is shut down, distributed state and memory
    are cleaned up, and ``VLLM_USE_V1`` is put back to its previous state.
    """
    # We cannot use monkeypatch since this is a module
    # scoped fixture and monkeypatch is function scoped.
    previous_value = os.getenv("VLLM_USE_V1", None)
    os.environ["VLLM_USE_V1"] = "0"
    try:
        engine = await asyncio.get_running_loop().run_in_executor(
            executor=None, func=start_engine)
        try:
            yield engine
        finally:
            engine.shutdown_background_loop()
            del engine
            await asyncio.sleep(0.1)
            cleanup_dist_env_and_memory()
    finally:
        # Restore the variable even if start-up or teardown raised. Compare
        # against None so a pre-existing empty-string value is restored
        # rather than deleted.
        if previous_value is not None:
            os.environ["VLLM_USE_V1"] = previous_value
        else:
            os.environ.pop("VLLM_USE_V1", None)

175
176
177
178
179
180
181
182

@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Fixture that always returns ``False``.

    NOTE(review): presumably overrides a same-named fixture consumed by a
    shared conftest to disable per-test global cleanup — confirm there.
    """
    # So we can share the async engine fixture between these tests
    return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Run two identical greedy requests concurrently and compare them.

    Both requests must produce the same final output, and the number of
    streamed outputs must match what the scheduler step count allows.
    """

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        """Generate for ``prompt``; return (last output, output count)."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        # One streamed output per generated token.
        assert output_count == 32
    else:
        # Multi-step scheduling yields fewer outputs than tokens.
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Generate with the given output kind and return the final result.

        Returns (prompt token ids, output token ids, output text,
        number of streamed outputs), all taken from the last output.
        """
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Generate in DELTA mode and stitch the increments back together.

        Returns the same 4-tuple shape as ``run`` so results are comparable.
        """
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: list[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                assert 1 <= len(token_ids) <= num_scheduler_steps
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    results = await asyncio.gather(
        run("common input prompt", RequestOutputKind.CUMULATIVE),
        run("common input prompt", RequestOutputKind.FINAL_ONLY),
        run_deltas("common input prompt"))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    # FINAL_ONLY must emit exactly one output.
    assert final[3] == 1
324
325
326


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Abort a request mid-stream and expect ``CancelledError``.

    The consumer must see exactly ``stop_at`` outputs, none of them finished,
    before the abort ends the stream.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    # Abort after the 5th output in single-step mode, after the 1st otherwise.
    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    assert i == stop_at
354
355
356


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """Check that a slow consumer still receives every streamed output.

    The consumer pauses after the first output, then must still observe all
    10 outputs, with only the last one marked finished.
    """
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            # Only the 10th (last) output may be marked finished.
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409


@pytest.mark.asyncio(scope="module")
async def test_invalid_argument(async_engine):
    """Passing ``data_parallel_rank`` to ``generate`` must raise ValueError.

    Skipped when the engine runs with more than one scheduler step.
    """
    config = await async_engine.get_scheduler_config()
    if config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    greedy_params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)

    async def drain_stream():
        """Start the rank-targeted request and consume whatever it yields."""
        stream = async_engine.generate("test",
                                       greedy_params,
                                       request_id=uid(),
                                       data_parallel_rank=0)
        async for _ in stream:
            pass

    # Targeting specific DP rank only supported in v1 multi-instance DP
    with pytest.raises(ValueError):
        await drain_stream()