test_async_llm.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from typing import Optional
7
from unittest.mock import MagicMock
8
9
10
11

import pytest

from vllm import SamplingParams
12
from vllm.assets.image import ImageAsset
13
from vllm.config import VllmConfig
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.inputs import PromptType
16
from vllm.platforms import current_platform
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.v1.engine.async_llm import AsyncLLM
19
from vllm.v1.metrics.loggers import LoggingStatLogger
20
21
22
23
24

# Skip the whole module at collection time on non-CUDA platforms.
# `allow_module_level=True` is required for calling pytest.skip() outside
# of a test function.
if not current_platform.is_cuda():
    pytest.skip(
        reason="V1 currently only supported on CUDA.",
        allow_module_level=True,
    )

25
26
27
28
29
# Engine arguments shared by the text-only test cases.
TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
                                   enforce_eager=True,
                                   disable_log_requests=True)
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

# Engine arguments shared by the image-input test cases.
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
                                     enforce_eager=True,
                                     disable_log_requests=True)

# Plain-text prompt used with TEXT_ENGINE_ARGS.
TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt text containing a single image placeholder
# (`<|image_pad|>`) between the vision start/end markers.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")

# Prompt dict pairing the template above with the image it refers to;
# used with VISION_ENGINE_ARGS.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {
        "image": ImageAsset("stop_sign").pil_image
    },
}
48
49


50
51
52
53
54
55
56
57
58
59
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: Optional[int] = None,
    cancel_after: Optional[int] = None,
) -> tuple[int, str]:
    """Run one request through ``engine`` and count the generated tokens.

    Args:
        engine: Engine whose ``generate()`` async iterator is consumed.
        request_id: Identifier passed through to ``engine.generate``.
        prompt: Prompt passed through to ``engine.generate``.
        output_kind: Output mode placed in the sampling params. It also
            selects how tokens are counted: for ``DELTA`` the per-output
            counts are accumulated; for any other kind the count of the
            most recent output replaces the running value.
        max_tokens: ``max_tokens`` for the sampling params (``ignore_eos``
            is set, so EOS does not end generation early).
        n: Number of samples requested per prompt.
        prompt_logprobs: Forwarded to the sampling params.
        cancel_after: If not ``None``, stop consuming the stream and return
            as soon as the running count reaches this value.

    Returns:
        ``(count, request_id)`` where ``count`` is the token count computed
        as described under ``output_kind``.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params):

        # Tokens across all samples carried by this output.
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        # Leave the stream early once enough tokens have been seen.
        if cancel_after is not None and count >= cancel_after:
            return count, request_id

        # Yield control to the event loop between outputs.
        await asyncio.sleep(0.0)

    return count, request_id


91
92
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Run many concurrent requests and check each yields the expected tokens.

    Starts ``NUM_REQUESTS`` ``generate()`` tasks at once, waits for them, and
    asserts every finished task produced exactly ``NUM_EXPECTED_TOKENS``
    tokens and that the output processor has no unfinished requests left.
    """
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the `with` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind,
                         NUM_EXPECTED_TOKENS))
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns early if any task raises; whatever is
        # still pending at that point is cancelled, and awaiting the failed
        # task below re-raises its exception.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()


140
141
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight requests and check the rest are unaffected.

    Every 10th request (starting at index 1) is given a very large
    ``max_tokens`` and its task is cancelled. All other requests must finish
    with the expected token count, including those that request ``n=3``
    samples. Afterwards a fresh request that reuses one of the cancelled
    request ids must run to completion.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the `with` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        # Indices of the requests to cancel / to run with n=3. Membership
        # tests against a range are constant-time.
        REQUEST_IDS_TO_ABORT = range(1, NUM_REQUESTS, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, NUM_REQUESTS, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            # Requests that will be cancelled get the long token budget.
            max_tokens = (NUM_EXPECTED_TOKENS_LONG
                          if idx in REQUEST_IDS_TO_ABORT else
                          NUM_EXPECTED_TOKENS)
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             max_tokens, n)))

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
211
212
213


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    monkeypatch: pytest.MonkeyPatch,
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Check that only the final streamed output has ``finished`` set.

    Collects every output of a single DELTA-mode request (with ``n`` samples)
    and asserts that ``finished`` is false on all but the last one.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the `with` block exits, pass or fail.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        outputs = [
            out
            async for out in engine.generate(request_id="request-33",
                                             prompt=prompt,
                                             sampling_params=sampling_params)
        ]

        # Fail with a clear message instead of an IndexError on outputs[-1]
        # if the stream produced nothing at all.
        assert outputs, "engine.generate yielded no outputs"

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
249
250


251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
                                       engine_args: AsyncEngineArgs,
                                       prompt: PromptType):
    """Requests that stop consuming their stream early leave nothing behind.

    Each request asks for many more tokens than it will read and returns
    from ``generate()`` once the cancel threshold is reached. The engine must
    end up with no unfinished requests and accept a reused request id.
    """
    with monkeypatch.context() as env, ExitStack() as cleanup:
        env.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        cleanup.callback(engine.shutdown)

        request_count = 100
        token_budget = 1000
        stop_after = 20

        request_ids = [f"request-{i}" for i in range(request_count)]

        # Launch every request concurrently; each one bails out mid-stream.
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    rid,
                    prompt,
                    RequestOutputKind.DELTA,
                    token_budget,
                    cancel_after=stop_after,
                )) for rid in request_ids
        ]

        # Block until every task has returned.
        results = await asyncio.gather(*tasks)

        # Each request must have stopped exactly at the threshold.
        for token_count, rid in results:
            assert token_count == stop_after, (
                f"{rid} generated {token_count} tokens but "
                f"expected to cancel after {stop_after}")

        # Nothing may still be tracked as in flight.
        assert not engine.output_processor.has_unfinished_requests()

        # A previously used request id must be usable again.
        followup = asyncio.create_task(
            generate(engine, request_ids[0], prompt, RequestOutputKind.DELTA,
                     stop_after))
        token_count, _ = await followup
        assert token_count == stop_after
        assert not engine.output_processor.has_unfinished_requests()


308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
class MockLoggingStatLogger(LoggingStatLogger):
    """LoggingStatLogger whose ``log`` attribute is replaced by a MagicMock.

    Lets a test assert on calls to ``log`` without running the real
    implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        # Run the normal initialisation first, then shadow `log` on the
        # instance so that calls to it are recorded by the mock.
        super().__init__(vllm_config, engine_index)
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Check that a stat-logger class supplied at construction is used.

    The engine is built with ``MockLoggingStatLogger`` as its only stat
    logger; after one ``do_log_stats()`` call that logger's mocked ``log``
    must have been invoked exactly once.
    """
    with monkeypatch.context() as env, ExitStack() as cleanup:
        env.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(
            TEXT_ENGINE_ARGS, stat_loggers=[MockLoggingStatLogger])
        cleanup.callback(engine.shutdown)

        await engine.do_log_stats()

        # The outer list must hold one entry, which itself holds one logger.
        assert len(engine.stat_loggers) == 1
        first_entry = engine.stat_loggers[0]
        assert len(first_entry) == 1
        first_entry[0].log.assert_called_once()
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    """Check ``data_parallel_rank`` handling in ``engine.generate``.

    Rank 0 must stream to completion; rank 1 must raise ``ValueError``.
    """
    with monkeypatch.context() as env, ExitStack() as cleanup:
        env.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        async def drain(request_id: str, rank: int) -> None:
            # Consume the whole stream for one request, discarding outputs.
            async for _ in engine.generate(request_id=request_id,
                                           prompt=TEXT_PROMPT,
                                           sampling_params=params,
                                           data_parallel_rank=rank):
                pass

        # A valid DP rank runs to completion.
        await drain("request-34", 0)

        # An out-of-range DP rank is rejected.
        with pytest.raises(ValueError):
            await drain("request-35", 1)