test_async_llm.py 9.52 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from typing import Optional
7
from unittest.mock import MagicMock
8
9
10
11

import pytest

from vllm import SamplingParams
12
from vllm.assets.image import ImageAsset
13
from vllm.config import VllmConfig
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.inputs import PromptType
16
from vllm.platforms import current_platform
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.v1.engine.async_llm import AsyncLLM
19
from vllm.v1.metrics.loggers import LoggingStatLogger
20
21
22
23
24

# Every test below drives the V1 engine, which is only supported on CUDA,
# so skip the entire module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(
        reason="V1 currently only supported on CUDA.",
        allow_module_level=True,
    )

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Engine configuration for the text-only test cases.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    disable_log_requests=True,
)

# Engine configuration for the image + text test cases.
VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct",
    enforce_eager=True,
    disable_log_requests=True,
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing a single image placeholder.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Prompt dict pairing the template above with the image it refers to.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {
        "image": ImageAsset("stop_sign").pil_image,
    },
}
46
47


48
49
async def generate(engine: AsyncLLM,
                   request_id: str,
                   prompt: PromptType,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   n: int = 1,
                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
    """Run one request through ``engine`` and count the tokens it yields.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine and returned unchanged.
        prompt: Prompt forwarded to the engine as-is.
        output_kind: Output mode placed in the sampling params; it also
            selects how per-iteration token counts are combined.
        max_tokens: ``max_tokens`` value for the sampling params.
        n: ``n`` value for the sampling params.
        prompt_logprobs: ``prompt_logprobs`` value for the sampling params.

    Returns:
        ``(count, request_id)``. For ``RequestOutputKind.DELTA`` ``count`` is
        the running sum of the token counts of every yielded output; for any
        other kind it is the token count of the last yielded output.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     temperature=0.5,
                                     seed=33,
                                     n=n,
                                     prompt_logprobs=prompt_logprobs)

    count = 0
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params):
        # One RequestOutput may carry several completions (n > 1); count
        # the tokens across all of them.
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        # Yield control to the event loop between outputs.
        await asyncio.sleep(0.)

    return count, request_id


81
82
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_load(monkeypatch: pytest.MonkeyPatch,
                    output_kind: RequestOutputKind,
                    engine_args: AsyncEngineArgs, prompt: PromptType):
    """Submit 100 concurrent requests and check each yields 10 tokens.

    Also asserts that the engine's output processor reports no unfinished
    requests once all tasks have been awaited.
    """
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind,
                         NUM_EXPECTED_TOKENS))
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns as soon as any task raises, or once every
        # task has finished if none does; anything still pending is
        # cancelled so a failure surfaces immediately.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises its exception, if any.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()


126
127
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args,prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_abort(monkeypatch: pytest.MonkeyPatch,
                     output_kind: RequestOutputKind,
                     engine_args: AsyncEngineArgs, prompt: PromptType):
    """Cancel a subset of in-flight requests and check the rest are intact.

    100 requests are submitted concurrently. Those whose index is in
    ``REQUEST_IDS_TO_ABORT`` ask for a very long generation and are then
    cancelled; those whose index is in ``PARALLEL_SAMPLE_REQ_IDS`` use
    ``n=3``. The test asserts that cancelled tasks raise
    ``asyncio.CancelledError``, that every other request yields the expected
    token count, that the output processor has no unfinished requests, and
    that a request id that was aborted can be used for a new generation.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, NUM_REQUESTS, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, NUM_REQUESTS, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            # Requests that will be cancelled get the long token budget.
            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             max_tokens, n)))

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
192
193
194


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args,prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
                             engine_args: AsyncEngineArgs, prompt: PromptType):
    """Check that only the final streamed output is marked ``finished``.

    A single ``DELTA`` request is run to completion and every yielded output
    is collected; all but the last must have ``finished`` false and the last
    must have it true.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(max_tokens=100,
                                         output_kind=RequestOutputKind.DELTA,
                                         temperature=1.0,
                                         seed=33,
                                         n=n)
        outputs = [
            out
            async for out in engine.generate(request_id="request-33",
                                             prompt=prompt,
                                             sampling_params=sampling_params)
        ]

        # Fail with a clear message rather than an IndexError below if the
        # engine yielded nothing at all.
        assert outputs, "engine.generate yielded no outputs"

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252


class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` whose ``log`` attribute is a ``MagicMock``.

    Lets a test assert how many times ``log`` was called on the instance.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Shadow ``log`` on the instance with a mock that records calls.
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Check that a stat logger class supplied at construction is used.

    The engine is built with ``MockLoggingStatLogger`` as its only stat
    logger; after one ``do_log_stats`` call the engine must hold exactly one
    logger, and that logger's ``log`` must have been called exactly once.
    """

    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        custom_loggers = [MockLoggingStatLogger]
        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS,
                                           stat_loggers=custom_loggers)
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        # ``stat_loggers`` is indexed twice below: one outer entry, holding
        # one logger instance.
        all_loggers = engine.stat_loggers
        assert len(all_loggers) == 1
        first_group = all_loggers[0]
        assert len(first_group) == 1
        first_group[0].log.assert_called_once()