test_async_llm_engine.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
6
import os
import uuid
7
from asyncio import CancelledError
8
from copy import copy
9
10
from dataclasses import dataclass, field
from typing import Any, Optional
11
12

import pytest
13
import pytest_asyncio
14
import torch
15

16
from vllm import SamplingParams
17
from vllm.config import ParallelConfig
18
from vllm.distributed import cleanup_dist_env_and_memory
19
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
20
from vllm.outputs import RequestOutput as RealRequestOutput
21
from vllm.sampling_params import RequestOutputKind
22
23

from ..utils import wait_for_gpu_memory_to_clear
24
25
26
27
28
29
30
31


@dataclass
class RequestOutput:
    """Minimal stand-in for an engine output in the mock-engine test."""
    # Request ids used in this file are strings ("1", "2", uuid strings),
    # so the field is annotated as ``str``.
    request_id: str
    finished: bool = False


32
33
34
@dataclass
class MockModelConfig:
    """Model-config stub carrying only the attributes set in this file."""
    # Unannotated, so this is a plain class attribute rather than a
    # dataclass field (it is not an ``__init__`` parameter).
    use_async_output_proc = True
    # Each instance gets its own fresh, empty dict for these two fields.
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
37
38


39
40
41
42
43
44
45
class MockEngine:
    """Fake engine that only counts how often its entry points are called.

    A single ``request_id`` slot marks whether a request is in flight:
    ``generate`` sets it, ``stop_generating`` clears it, and the
    ``has_unfinished_requests*`` methods report whether it is set.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig()
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):
        """Count one step; return one output while a request id is set."""
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op; accepts and ignores any arguments."""
        pass

    async def stop_remote_worker_execution_loop_async(self):
        """No-op."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently in flight."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the in-flight request."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add-request call and print the running total."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Count an asynchronous add-request call; arguments are ignored."""
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort call; the in-flight request is left untouched."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request id is set."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same as ``has_unfinished_requests``; ``virtual_engine`` is unused."""
        return self.request_id is not None

87
88

class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass whose ``_engine_class`` is ``MockEngine``."""
    _engine_class = MockEngine
90
91
92
93


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check when the background loop steps the (mock) engine.

    The assertions below expect: no step before any request is added, one
    step per added request while nothing is in flight, repeated stepping
    while ``MockEngine.generate`` has marked a request in flight, and no
    further steps after ``stop_generating``.
    """
    params = SamplingParams()

    engine = MockAsyncLLMEngine()
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been added yet, so no step is expected.
    assert engine.engine.step_calls == 0

    await engine.add_request("1", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    await engine.add_request("2", "", params)
    engine.engine.generate("2")
    # Yield to the event loop three times so the background task can run.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    # Once the in-flight marker is cleared the step count must stop growing.
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    await engine.add_request("3", "", params)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    # Waiting longer must not trigger additional adds or steps.
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    engine = MockAsyncLLMEngine()
    # NOTE(review): ``get_scheduler_config`` is awaited elsewhere in this
    # file; if these getters are coroutines too, the un-awaited calls below
    # only verify that a coroutine object is returned -- confirm.
    assert engine.get_model_config() is not None
    assert engine.get_tokenizer() is not None
    assert engine.get_decoding_config() is not None
133
134


135
def start_engine():
    """Build an ``AsyncLLMEngine`` for ``facebook/opt-125m`` in eager mode.

    ``wait_for_gpu_memory_to_clear`` is called first for every CUDA device
    reported by torch (threshold 2 GiB, 60 s timeout). The number of
    scheduler steps is read from the ``NUM_SCHEDULER_STEPS`` environment
    variable and defaults to 1.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
    print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

    return AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        enforce_eager=True,
                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
    """Return a new random (version 4) UUID rendered as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
153
154
155
156


@pytest_asyncio.fixture(scope="module")
async def async_engine():
    """Module-scoped fixture yielding an engine built by ``start_engine``.

    ``VLLM_USE_V1`` is set to ``"0"`` for the lifetime of the fixture and
    restored to its previous state afterwards, including when engine
    start-up or teardown raises.
    """
    # We cannot use monkeypatch since this is a module
    # scoped fixture and monkeypatch is function scoped.
    previous_value = os.environ.get("VLLM_USE_V1")
    os.environ["VLLM_USE_V1"] = "0"
    try:
        # start_engine() blocks, so run it in the loop's default executor.
        engine = await asyncio.get_running_loop().run_in_executor(
            None, start_engine)
        try:
            yield engine
        finally:
            engine.shutdown_background_loop()
            del engine
            await asyncio.sleep(0.1)
            cleanup_dist_env_and_memory()
    finally:
        # ``is not None`` so a pre-existing empty value is restored as-is.
        if previous_value is not None:
            os.environ["VLLM_USE_V1"] = previous_value
        else:
            os.environ.pop("VLLM_USE_V1", None)

176
177
178
179
180
181
182
183

@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Return False so the async engine fixture can be shared between
    the tests in this module."""
    return False


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):
    """Run the same greedy request twice concurrently and compare results.

    Both requests use temperature 0 and exactly 32 tokens, so apart from
    the metrics and request id their final outputs must render identically.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    async def run(prompt: str):
        """Stream one request; return (last output, number of outputs)."""
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
            min_tokens=32,
            stop=stop,
        )

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  sampling_params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output
        return final_output, output_count

    results = await asyncio.gather(
        run("test0"),
        run("test0"),
    )
    assert len(results) == 2
    first, second = results

    # remove nondeterministic fields for comparison
    first[0].metrics = None
    second[0].metrics = None
    first[0].request_id = None
    second[0].request_id = None

    assert str(first) == str(second)

    output_count = results[0][1]
    if num_scheduler_steps == 1:
        # One streamed output is expected per generated token.
        assert output_count == 32
    else:
        assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
    """Test that output_kind works as expected and that
    results are equivalent across different kinds."""

    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=32,
        min_tokens=32,
        stop=stop,
    )

    async def run(prompt: str, kind: RequestOutputKind):
        """Stream one request with the given output kind.

        Returns (prompt token ids, output token ids, text, output count),
        all taken from the last streamed output.
        """
        params = copy(sampling_params)
        params.output_kind = kind

        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            output_count += 1
            final_output = output

        assert final_output is not None
        assert final_output.finished

        return (final_output.prompt_token_ids,
                final_output.outputs[0].token_ids,
                final_output.outputs[0].text, output_count)

    async def run_deltas(prompt: str):
        """Stream one DELTA request and stitch the pieces back together.

        Returns the same 4-tuple shape as ``run``, with the token ids and
        text accumulated across all streamed outputs.
        """
        params = copy(sampling_params)
        params.output_kind = RequestOutputKind.DELTA

        prompt_tokens = None
        output_tokens: list[int] = []
        output_text = ""
        output_count = 0
        final_output = None
        async for output in async_engine.generate(prompt,
                                                  params,
                                                  request_id=uid()):
            token_ids = output.outputs[0].token_ids
            text = output.outputs[0].text
            final_output = output

            # Ensure we get prompt ids iff we haven't yet received output tokens
            if output_tokens:
                assert 1 <= len(token_ids) <= num_scheduler_steps
                assert stop or text
                assert not output.prompt_token_ids
            else:
                assert output.prompt_token_ids
                prompt_tokens = output.prompt_token_ids

            output_tokens.extend(token_ids)
            output_text += text

            output_count += 1

        assert final_output is not None
        assert final_output.finished

        return prompt_tokens, output_tokens, output_text, output_count

    results = await asyncio.gather(
        run("common input prompt", RequestOutputKind.CUMULATIVE),
        run("common input prompt", RequestOutputKind.FINAL_ONLY),
        run_deltas("common input prompt"))

    # Make sure outputs are the same
    prompt_set = {tuple(prompt_ids) for prompt_ids, _, _, _ in results}
    assert len(prompt_set) == 1

    text_set = {text for _, _, text, _ in results}
    assert len(text_set) == 1

    tokens_set = {tuple(ids) for _, ids, _, _ in results}
    assert len(tokens_set) == 1

    cumulative, final, deltas = results

    # output message counts
    assert cumulative[3] == deltas[3]

    if num_scheduler_steps == 1:
        assert cumulative[3] == 32
    else:
        assert 1 < cumulative[3] < 32

    assert final[3] == 1
325
326
327


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
    """Abort a request mid-stream and expect ``CancelledError``.

    The request is aborted after ``stop_at`` streamed outputs; no output
    seen up to that point may be marked finished, and no further outputs
    may be delivered afterwards.
    """
    scheduler_config = await async_engine.get_scheduler_config()
    num_scheduler_steps = scheduler_config.num_scheduler_steps

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=13,
        max_tokens=13,
        stop=stop,
    )

    # Abort after the first output when running with multiple scheduler
    # steps, otherwise after the fifth.
    stop_at = 5 if num_scheduler_steps == 1 else 1

    request_id = uid()

    i = 0
    with pytest.raises(CancelledError):
        async for output in async_engine.generate("test2",
                                                  sampling_params,
                                                  request_id=request_id):
            assert not output.finished
            i += 1
            if i == stop_at:
                await async_engine.abort(request_id)

    assert i == stop_at
355
356
357


@pytest.mark.asyncio(scope="module")
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
    """Consume a stream only after pausing on its first output.

    The consumer sleeps for a second after the first output, then expects
    to still receive all 10 outputs with only the last marked finished.
    """
    scheduler_config = await async_engine.get_scheduler_config()

    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    sampling_params = SamplingParams(
        temperature=0,
        min_tokens=10,
        max_tokens=10,
        stop=stop,
    )

    stream = async_engine.generate("test3", sampling_params, request_id=uid())
    i = 0
    final_output: Optional[RealRequestOutput] = None
    async for output in stream:
        final_output = output
        if i == 0:
            # wait for generation to complete before consuming
            # the remaining messages
            await asyncio.sleep(1)
        if i < 9:
            assert not output.finished
        i += 1

    assert i == 10
    assert final_output is not None
    assert len(final_output.outputs[0].token_ids) == 10
    assert final_output.finished
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410


@pytest.mark.asyncio(scope="module")
async def test_invalid_argument(async_engine):
    """Passing ``data_parallel_rank`` to ``generate`` must raise ValueError.

    Skipped unless the engine runs with a single scheduler step.
    """
    config = await async_engine.get_scheduler_config()
    if config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)

    # Targeting specific DP rank only supported in v1 multi-instance DP
    with pytest.raises(ValueError):
        stream = async_engine.generate("test",
                                       params,
                                       request_id=uid(),
                                       data_parallel_rank=0)
        async for _ in stream:
            pass