# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from contextlib import ExitStack
7
from dataclasses import dataclass
8
9
10
11

import pytest

from vllm import SamplingParams
12
from vllm.config import VllmConfig
13
14
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
15
from vllm.platforms import current_platform
16
17
18
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
19
from vllm.v1.metrics.loggers import StatLoggerBase
20
from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats
21
22

# Data-parallel size handed to the engine under test; the DP_SIZE environment
# variable overrides the default of 2.
DP_SIZE = int(os.environ.get("DP_SIZE", 2))
23
24


25
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    prompt_logprobs: int | None = None,
    data_parallel_rank: int | None = None,
) -> tuple[int, str]:
    """Run a single request through ``engine`` and count its generated tokens.

    Sampling uses ``temperature=0`` and ``ignore_eos=True`` with the given
    ``max_tokens``, ``output_kind`` and ``prompt_logprobs``.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine and echoed in the result.
        prompt: Prompt to generate from.
        output_kind: Output mode requested from the engine; it also selects
            how the token count is accumulated (see below).
        max_tokens: Maximum number of tokens to request.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        data_parallel_rank: Forwarded to ``engine.generate``.

    Returns:
        ``(count, request_id)``. For ``RequestOutputKind.DELTA`` the count is
        the sum of the token-id lengths of every output received; for any
        other kind it is the token-id length of the last output received.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0,
        prompt_logprobs=prompt_logprobs,
    )
    async for out in engine.generate(
        request_id=request_id,
        prompt=prompt,
        sampling_params=sampling_params,
        data_parallel_rank=data_parallel_rank,
    ):
        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            # Each delta output only carries its own tokens, so add them up.
            count += num_tokens
        else:
            # Otherwise keep the size reported by the most recent output.
            count = num_tokens

        # Yield to the event loop so other pending tasks can make progress.
        await asyncio.sleep(0.0)

    return count, request_id


62
63
64
65
66
67
68
@pytest.mark.parametrize(
    "model",
    [
        "ibm-research/PowerMoE-3b",
        "hmellor/tiny-random-LlamaForCausalLM",
    ],
)
@pytest.mark.parametrize(
    "output_kind",
    [
        RequestOutputKind.DELTA,
        RequestOutputKind.FINAL_ONLY,
    ],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio
async def test_load(
    model: str,
    output_kind: RequestOutputKind,
    data_parallel_backend: str,
    async_scheduling: bool,
):
    """Send many concurrent requests to a data-parallel ``AsyncLLM``.

    The test checks that:
      * every request produces exactly ``NUM_EXPECTED_TOKENS`` tokens,
      * the engine is idle afterwards (no unfinished requests, no engines
        running, no requests in flight),
      * one stats logger was registered per engine index, and each of them
        saw more than ``NUM_REQUESTS // (DP_SIZE + 1)`` finished requests.

    Args:
        model: Model name passed to ``AsyncEngineArgs``.
        output_kind: Output mode used for every request.
        data_parallel_backend: Data-parallel backend (``"mp"`` or ``"ray"``).
        async_scheduling: Whether async scheduling is enabled on the engine.
    """
    if async_scheduling and data_parallel_backend == "ray":
        # TODO(NickLucche) Re-enable when async scheduling is supported
        pytest.skip("Async scheduling is not supported with ray")
    elif data_parallel_backend == "ray" and current_platform.is_rocm():
        pytest.skip(
            "Ray as the distributed executor backend is not supported with ROCm."
        )

    # Filled in by SimpleStatsLogger.__init__, keyed by engine index.
    stats_loggers = {}

    @dataclass
    class SimpleStatsLogger(StatLoggerBase):
        """Stat logger that counts initializations and finished requests."""

        init_count: int = 0
        finished_req_count: int = 0

        def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
            # Register this instance so the test body can inspect it later.
            stats_loggers[engine_index] = self

        def record(
            self,
            scheduler_stats: SchedulerStats | None,
            iteration_stats: IterationStats | None,
            mm_cache_stats: MultiModalCacheStats | None = None,
            engine_idx: int = 0,
        ):
            # Only iteration stats are used; the other arguments are ignored.
            if iteration_stats:
                self.finished_req_count += len(iteration_stats.finished_requests)

        def log_engine_initialized(self):
            self.init_count += 1

    with ExitStack() as after:
        prompt = "This is a test of data parallel"

        engine_args = AsyncEngineArgs(
            model=model,
            enforce_eager=True,
            tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
            data_parallel_size=DP_SIZE,
            data_parallel_backend=data_parallel_backend,
            async_scheduling=async_scheduling,
        )
        engine = AsyncLLM.from_engine_args(
            engine_args, stat_loggers=[SimpleStatsLogger]
        )
        # Shut the engine down when the block exits, including on failure.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS
                    )
                )
            )
            # Short sleep to ensure that requests are distributed.
            await asyncio.sleep(0.01)

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns as soon as any task raises, or once all
        # tasks have finished if none does.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises any exception it stored.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()

        # testing internals here which may break
        core_client: DPAsyncMPClient = engine.engine_core
        # the engines only synchronize stopping every N steps so
        # allow a small amount of time here (up to 10 polls, 0.5 s apart).
        for _ in range(10):
            if not core_client.engines_running:
                break
            await asyncio.sleep(0.5)

        assert not core_client.engines_running
        assert not core_client.reqs_in_flight

        # Check that requests were distributed between the engines
        print(f"Stats loggers after test: {stats_loggers}")
        assert len(stats_loggers) == DP_SIZE
        assert stats_loggers[0].init_count == 1

        for slogger in stats_loggers.values():
            # Loose balance bound: with the defaults (100 requests, DP_SIZE=2)
            # every engine must have finished more than 33 requests.
            assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), (
                f"requests are imbalanced: {stats_loggers}"
            )