test_async_llm.py 6.86 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import asyncio
4
from contextlib import ExitStack
5
from typing import List, Optional, Tuple
6
7
8

import pytest

9
from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
10
11
12
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
13
from vllm.sampling_params import RequestOutputKind
14
15
16
17
18
19
from vllm.v1.engine.async_llm import AsyncLLM

# The V1 engine exercised by this module only runs on CUDA, so skip the
# whole module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(
        reason="V1 currently only supported on CUDA.",
        allow_module_level=True,
    )

20
# Engine configuration shared by the tests in this module that build an
# AsyncLLM via ``AsyncLLM.from_engine_args(ENGINE_ARGS)``.
ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B-Instruct",
                              enforce_eager=True,
                              disable_log_requests=True)


25
26
async def generate(engine: AsyncLLM,
                   request_id: str,
                   output_kind: RequestOutputKind,
                   max_tokens: int,
                   prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
    """Run a single request through ``engine`` and count generated tokens.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine and returned to the
            caller alongside the count.
        output_kind: Output mode placed on the sampling params. With
            ``RequestOutputKind.DELTA`` the per-output token counts are
            summed; for any other kind the latest output's token count
            replaces the running value.
        max_tokens: Forwarded to ``SamplingParams(max_tokens=...)``.
        prompt_logprobs: Forwarded to ``SamplingParams``; defaults to
            ``None``.

    Returns:
        A ``(count, request_id)`` tuple.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     output_kind=output_kind,
                                     temperature=0,
                                     prompt_logprobs=prompt_logprobs)
    async for out in engine.generate(request_id=request_id,
                                     prompt="Hello my name is Robert and",
                                     sampling_params=sampling_params):

        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            # Each output is counted on top of what was seen so far.
            count += num_tokens
        else:
            # The latest output's length is taken as the total.
            count = num_tokens

        # Yield control to the event loop between outputs.
        await asyncio.sleep(0.)

    return count, request_id


53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_async_llm_refuses_prompt_logprobs_with_apc(
        monkeypatch, output_kind: RequestOutputKind):
    """Check that an APC-enabled AsyncLLM rejects prompt_logprobs requests.

    The engine is built with automatic prefix caching turned on; a request
    that asks for prompt logprobs must then fail with a ValueError whose
    message equals PLP_APC_UNSUPPORTED_MSG.
    """
    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
    # better way to test V1 so that in the future when we switch, we don't
    # have to change all the tests.
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Build the engine with prefix caching enabled.
    prefix_caching_args = AsyncEngineArgs(
        model="facebook/opt-125m",
        enable_prefix_caching=True,
        gpu_memory_utilization=0.8,
        disable_log_requests=True,
    )
    engine = AsyncLLM.from_engine_args(prefix_caching_args)

    with ExitStack() as after:
        # The engine is shut down whether or not the assertions pass.
        after.callback(engine.shutdown)

        # Submit a request with prompt logprobs enabled; it should fail.
        request = asyncio.create_task(
            generate(engine, "request-0", output_kind, 10, prompt_logprobs=5))
        with pytest.raises(ValueError) as raised:
            await request

        # The error text has to match the expected message exactly.
        assert str(raised.value) == PLP_APC_UNSUPPORTED_MSG


87
88
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_load(monkeypatch, output_kind: RequestOutputKind):
    """Submit many concurrent requests and check each one's token count.

    Every request is expected to report exactly NUM_EXPECTED_TOKENS tokens,
    and the engine's output processor must have no unfinished requests
    afterwards.
    """
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 10000
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, output_kind,
                         NUM_EXPECTED_TOKENS))
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # Waiting stops at the first exception so a failure surfaces
        # without waiting for the remaining requests.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises any exception it hit.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()


127
128
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_abort(monkeypatch, output_kind: RequestOutputKind):
    """Cancel a subset of in-flight requests and check the rest complete.

    Cancelled tasks must raise ``asyncio.CancelledError``; every other task
    must report NUM_EXPECTED_TOKENS tokens. Afterwards the output processor
    must have no unfinished requests, and a fresh request that reuses a
    cancelled request id must still generate normally.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
        # Shut the engine down when the ``with`` block exits, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: List[asyncio.Task] = [
            asyncio.create_task(
                generate(engine, request_id, output_kind,
                         NUM_EXPECTED_TOKENS))
            for request_id in request_ids
        ]

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()