# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import os
from contextlib import ExitStack
from typing import Optional

import pytest

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient

# Engine configuration used by the tests in this module. The tensor- and
# data-parallel sizes are read from the TP_SIZE / DP_SIZE environment
# variables, falling back to 1 and 2 respectively.
engine_args = AsyncEngineArgs(
    model="ibm-research/PowerMoE-3b",
    enforce_eager=True,
    disable_log_requests=True,
    tensor_parallel_size=int(os.environ.get("TP_SIZE", 1)),
    data_parallel_size=int(os.environ.get("DP_SIZE", 2)),
)

# Skip the entire module when the current platform reports that it does not
# support V1 for the model config derived from the arguments above.
if not current_platform.supports_v1(engine_args.create_model_config()):
    pytest.skip(
        reason="Requires V1-supporting platform.",
        allow_module_level=True,
    )


32
33
34
35
36
37
38
39
async def generate(
        engine: AsyncLLM,
        request_id: str,
        prompt: PromptType,
        output_kind: RequestOutputKind,
        max_tokens: int,
        prompt_logprobs: Optional[int] = None,
        data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
    """Run a single request through ``engine`` and count generated tokens.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine; returned unchanged.
        prompt: Prompt passed to ``engine.generate``.
        output_kind: Output mode placed in the sampling params. It also
            selects how per-output token counts are combined (see Returns).
        max_tokens: Passed to ``SamplingParams`` as ``max_tokens``.
        prompt_logprobs: Passed to ``SamplingParams`` as ``prompt_logprobs``.
        data_parallel_rank: Passed through to ``engine.generate``.

    Returns:
        A ``(count, request_id)`` tuple. For ``RequestOutputKind.DELTA`` the
        count is the sum of the token counts of every output received; for
        any other kind it is the token count of the last output received
        (0 if the engine yielded nothing).
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(max_tokens=max_tokens,
                                     ignore_eos=True,
                                     output_kind=output_kind,
                                     temperature=0,
                                     prompt_logprobs=prompt_logprobs)
    async for out in engine.generate(request_id=request_id,
                                     prompt=prompt,
                                     sampling_params=sampling_params,
                                     data_parallel_rank=data_parallel_rank):

        num_tokens = len(out.outputs[0].token_ids)
        if output_kind == RequestOutputKind.DELTA:
            # Each output's token count is added to the running total.
            count += num_tokens
        else:
            # The latest output's token count replaces the previous value.
            count = num_tokens

        # Yield to the event loop so other tasks get a chance to run
        # between outputs.
        await asyncio.sleep(0.)

    return count, request_id


@pytest.mark.parametrize(
    "output_kind",
    [
        RequestOutputKind.DELTA,
        RequestOutputKind.FINAL_ONLY,
    ],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind,
                    data_parallel_backend: str):
    """Issue many concurrent requests and check they all complete cleanly.

    Builds an ``AsyncLLM`` from the module-level ``engine_args`` using the
    parametrized data-parallel backend, launches 100 concurrent ``generate``
    tasks, and asserts that each one produced exactly 10 tokens. Afterwards
    it checks that the output processor has no unfinished requests and that
    the core client reports no running engines and no requests in flight.

    Args:
        output_kind: Output mode requested for every generation.
        data_parallel_backend: Value assigned to
            ``engine_args.data_parallel_backend`` ("mp" or "ray").
    """

    with ExitStack() as after:

        prompt = "This is a test of data parallel"

        # NOTE: this mutates the shared module-level engine_args object; it
        # is reassigned on every parametrized run before the engine is built.
        engine_args.data_parallel_backend = data_parallel_backend
        engine = AsyncLLM.from_engine_args(engine_args)
        # Registered on the ExitStack so shutdown runs even if an
        # assertion below fails.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine,
                         request_id,
                         prompt,
                         output_kind,
                         NUM_EXPECTED_TOKENS,
                         data_parallel_rank=0)) for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns as soon as any task raises; anything still
        # pending at that point is cancelled, and awaiting the failed task
        # below re-raises its exception so the test fails.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()

        # testing internals here which may break
        core_client: DPAsyncMPClient = engine.engine_core
        # the engines only synchronize stopping every N steps so
        # allow a small amount of time here.
        for _ in range(10):
            if not core_client.engines_running:
                break
            await asyncio.sleep(0.5)

        assert not core_client.engines_running
        assert not core_client.reqs_in_flight