test_async_llm.py 8.72 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from contextlib import ExitStack
5
from typing import Optional
6
7
8
9

import pytest

from vllm import SamplingParams
10
from vllm.assets.image import ImageAsset
11
from vllm.engine.arg_utils import AsyncEngineArgs
12
from vllm.inputs import PromptType
13
from vllm.platforms import current_platform
14
from vllm.sampling_params import RequestOutputKind
15
16
17
18
19
20
from vllm.v1.engine.async_llm import AsyncLLM

# Skip every test in this module when the current platform is not CUDA.
# ``allow_module_level=True`` is what permits calling ``pytest.skip`` outside
# of a test function.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Engine arguments paired with TEXT_PROMPT in the parametrized tests.
TEXT_ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
                                   enforce_eager=True,
                                   disable_log_requests=True)

# Engine arguments paired with VISION_PROMPT in the parametrized tests.
VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct",
                                     enforce_eager=True,
                                     disable_log_requests=True)

# Plain-text prompt used for the text-only model.
TEXT_PROMPT = "Hello my name is Robert and"

# Chat-style prompt text for the vision model: a system turn, a user turn
# holding the vision marker tokens plus a question, and an open assistant turn.
# NOTE(review): presumably <|image_pad|> is where the image is substituted --
# confirm against the model's prompt format.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")
# Multimodal prompt: the template text above together with the "stop_sign"
# image asset, passed as a PIL image under the "image" key.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {
        "image": ImageAsset("stop_sign").pil_image
    }
}
42
43


44
45
async def generate(engine: AsyncLLM,
                   request_id: str,
                   prompt: PromptType,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   n: int = 1,
                   prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
    """Run a single request through ``engine`` and count generated tokens.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier forwarded to the engine and echoed back.
        prompt: Prompt forwarded to the engine.
        output_kind: Output kind set on the sampling params; it also selects
            how the token count is tracked (see below).
        max_tokens: ``max_tokens`` sampling parameter (``ignore_eos`` is set).
        n: ``n`` sampling parameter.
        prompt_logprobs: ``prompt_logprobs`` sampling parameter.

    Returns:
        ``(count, request_id)``. For ``RequestOutputKind.DELTA`` ``count`` is
        the sum of token counts over every yielded output; for any other kind
        it is the token count of the last yielded output. It is 0 if the
        engine yields nothing.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     temperature=0.5,
                                     seed=33,
                                     n=n,
                                     prompt_logprobs=prompt_logprobs)

    accumulate = output_kind == RequestOutputKind.DELTA
    count = 0
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params):
        # Token total across all completions carried by this output.
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        # DELTA: add to the running total; otherwise keep only the latest.
        count = count + num_tokens if accumulate else num_tokens

        # Hand control back to the event loop between outputs.
        await asyncio.sleep(0.)

    return count, request_id


77
78
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind,
                    engine_args_and_prompt: tuple[AsyncEngineArgs,
                                                  PromptType]):
    """Issue many concurrent requests and check each one's token count.

    Starts ``NUM_REQUESTS`` ``generate`` tasks against one engine, waits until
    all finish or one raises, and asserts every finished task produced exactly
    ``NUM_EXPECTED_TOKENS`` tokens and that the engine's output processor has
    no unfinished requests left.
    """
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")
        engine_args, prompt = engine_args_and_prompt

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, f"request-{i}", prompt, output_kind,
                         NUM_EXPECTED_TOKENS)) for i in range(NUM_REQUESTS)
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns as soon as any task raises (or once all
        # finish); whatever is still pending at that point is cancelled.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises its exception, if any.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()


123
124
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_abort(monkeypatch, output_kind: RequestOutputKind,
                     engine_args_and_prompt: tuple[AsyncEngineArgs,
                                                   PromptType]):
    """Cancel a subset of in-flight requests and check the rest are unharmed.

    Starts ``NUM_REQUESTS`` ``generate`` tasks; those whose index is in
    ``REQUEST_IDS_TO_ABORT`` get a very long ``max_tokens`` and are cancelled,
    and those in ``PARALLEL_SAMPLE_REQ_IDS`` use ``n=3``. The test asserts
    that cancelled tasks raise ``asyncio.CancelledError``, that every other
    task produced ``NUM_EXPECTED_TOKENS * n`` tokens, that no requests remain
    unfinished, and that a fresh request reusing an aborted id still succeeds.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")
        engine_args, prompt = engine_args_and_prompt

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
        # Number of parallel samples per request, indexed like ``request_ids``;
        # computed once so submission and verification cannot disagree.
        sample_counts = [
            3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            for idx in range(NUM_REQUESTS)
        ]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            # Requests that will be cancelled ask for far more tokens than
            # the others.
            max_tokens = NUM_EXPECTED_TOKENS_LONG if (
                idx in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             max_tokens, sample_counts[idx])))

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                expected_tokens = NUM_EXPECTED_TOKENS * sample_counts[idx]
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args_and_prompt",
                         [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
                          (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio
async def test_finished_flag(monkeypatch, n: int,
                             engine_args_and_prompt: tuple[AsyncEngineArgs,
                                                           PromptType]):
    """Check that ``finished`` is set on the last streamed output only.

    Streams one DELTA-mode request with ``n`` samples, collects every output
    the engine yields, and asserts that the flag is false on all but the
    final element and true on the final one.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")
        engine_args, prompt = engine_args_and_prompt

        engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(max_tokens=100,
                                         output_kind=RequestOutputKind.DELTA,
                                         temperature=1.0,
                                         seed=33,
                                         n=n)

        # Drain the stream into a list so the outputs can be inspected by
        # position afterwards.
        outputs = []
        async for out in engine.generate(request_id="request-33",
                                         prompt=prompt,
                                         sampling_params=sampling_params):
            outputs.append(out)

        # Assert only the last output has the finished flag set
        assert not any(out.finished for out in outputs[:-1])
        assert outputs[-1].finished