test_context_parallel.py 8.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
10

11
12
13
import json
import os
from dataclasses import dataclass
14
from typing import Literal, NamedTuple
15
16

import pytest
17
import torch
18

19
20
from tests.evals.gsm8k.gsm8k_eval import evaluate_gsm8k
from tests.utils import RemoteOpenAIServer, create_new_process_for_each_test
21
from vllm.config.model import RunnerOption
22
23
24
25
26
27
28
29
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS

# Logger shared by this test module.
logger = init_logger("test_context_parallel")

# Multi-node mode is enabled only when VLLM_MULTI_NODE is exactly "1"; an
# unset variable falls back to "0", i.e. single-node mode.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"
# Models that test_cp_generation is allowed to parametrize over.
CP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "Qwen/Qwen2.5-1.5B-Instruct",
]

# ---- GSM8K evaluation settings -------------------------------------------
# A 256-question subset keeps the CI run short.
NUM_QUESTIONS = 256
# Number of few-shot examples placed in each prompt.
NUM_SHOTS = 5

# Per-model accuracy floor: the TP-only baseline accuracy with a 2% buffer.
# The baseline for each entry comes from the lm-eval-harness config noted
# above it.
MIN_ACCURACY = {
    # baseline: .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.64,
    # baseline: .buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml
    "Qwen/Qwen2.5-1.5B-Instruct": 0.52,
}

class ParallelSetup(NamedTuple):
    """One parallelism layout for a context-parallel test run.

    Field order matters: ``_test_cp_gsm8k`` unpacks instances positionally.
    """

    # Passed as --tensor-parallel-size.
    tp_size: int
    # Passed as --pipeline-parallel-size.
    pp_size: int
    # Passed as --decode-context-parallel-size.
    dcp_size: int
    # Passed as --dcp-kv-cache-interleave-size.
    cp_kv_cache_interleave_size: int
    # When True, the server is started with --enforce-eager.
    eager_mode: bool
    # When True, the server is started with --enable-chunked-prefill.
    chunked_prefill: bool


class CPTestOptions(NamedTuple):
    multi_node_only: bool
60
    attn_backend: str | None = None
61
62
63
64
65
66
67
68
69
70
71
72
73
74


@dataclass
class CPTestSettings:
    """A group of context-parallel setups sharing one runner and option set.

    ``iter_params`` expands the group into the positional tuples consumed by
    the ``pytest.mark.parametrize`` on ``test_cp_generation``.
    """

    parallel_setups: list[ParallelSetup]
    distributed_backends: list[str]
    runner: RunnerOption
    test_options: CPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 4,
        pp_base: int = 1,
        dcp_multipliers: list[float] | None = None,
        cp_kv_cache_interleave_size: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        attn_backend: str | None = None,
    ) -> "CPTestSettings":
        """Build settings for one TP size and a list of DCP multipliers.

        Args:
            tp_base: Tensor-parallel size used by every generated setup.
            pp_base: Pipeline-parallel size used by every generated setup.
            dcp_multipliers: One setup is generated per entry, with DCP size
                ``dcp_multiplier * tp_base``. Defaults to ``[0.5]``.
            cp_kv_cache_interleave_size: Copied into every generated setup.
            multi_node_only: Only run these setups in multi-node mode.
            runner: Runner option; ``"auto"`` means no ``--runner`` flag.
            attn_backend: Attention backend to force, or None for default.

        Returns:
            Settings that use only the "mp" distributed backend, with eager
            mode off and chunked prefill on for every setup.

        Raises:
            ValueError: If ``dcp_multiplier * tp_base`` is not a positive
                whole number. Previously ``int()`` silently truncated such
                values, for example to a DCP size of 0.
        """
        if dcp_multipliers is None:
            dcp_multipliers = [0.5]

        dcp_sizes = []
        for dcp_multiplier in dcp_multipliers:
            raw_dcp_size = dcp_multiplier * tp_base
            if raw_dcp_size < 1 or raw_dcp_size != int(raw_dcp_size):
                raise ValueError(
                    f"dcp_multiplier={dcp_multiplier} with tp_base={tp_base} "
                    f"gives DCP size {raw_dcp_size}, which is not a positive "
                    "whole number"
                )
            dcp_sizes.append(int(raw_dcp_size))

        # Eager mode (off), the PP multiplier (1) and chunked prefill (on)
        # each have a single value, so one setup is produced per DCP size,
        # in the order the multipliers were given.
        parallel_setups = [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_base,
                dcp_size=dcp_size,
                cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
                eager_mode=False,
                chunked_prefill=True,
            )
            for dcp_size in dcp_sizes
        ]

        return CPTestSettings(
            parallel_setups=parallel_setups,
            distributed_backends=["mp"],
            runner=runner,
            test_options=CPTestOptions(
                multi_node_only=multi_node_only,
                attn_backend=attn_backend,
            ),
        )

    def iter_params(self, model_id: str):
        """Yield one parametrize tuple per (parallel setup, backend) pair.

        Each tuple is ``(model_id, parallel_setup, backend, runner, options)``.
        This matches the argnames of ``test_cp_generation``'s parametrize
        mark.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (model_id, parallel_setup, backend, self.runner, opts)


# Context-parallel settings per model. Only models also listed in
# CP_TEST_MODELS are parametrized into test_cp_generation.
CP_TEXT_GENERATION_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
        # DCP multiplier 1 with the default TP size of 4 -> DCP size 4.
        CPTestSettings.detailed(dcp_multipliers=[1]),
        # DCP multiplier 0.5 -> DCP size 2, forcing the FLASHMLA backend.
        CPTestSettings.detailed(
            dcp_multipliers=[0.5],
            cp_kv_cache_interleave_size=64,
            attn_backend="FLASHMLA",
        ),
    ],
    "Qwen/Qwen2.5-1.5B-Instruct": [
        # Default DCP multiplier (0.5) -> DCP size 2; one entry per backend.
        CPTestSettings.detailed(
            cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
        ),
        CPTestSettings.detailed(
            cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
        ),
    ],
}


def _test_cp_gsm8k(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
    is_multimodal: bool,
):
    """Start a ``RemoteOpenAIServer`` for one CP setup and check GSM8K accuracy.

    The server is launched with the tensor/pipeline/decode-context parallel
    sizes from ``parallel_setup`` and the options from ``test_options``. The
    test passes when GSM8K accuracy reaches ``MIN_ACCURACY[model_id]``.

    The test is skipped in any of these cases:

    * The ``HF_EXAMPLE_MODELS`` checks (called with ``on_fail="skip"``)
      reject the model.
    * Fewer than ``tp_size * pp_size`` GPUs are available.
    * Running multi-node with the "mp" backend.
    * The setting is multi-node only but ``VLLM_MULTI_NODE`` is not enabled.

    ``method`` and ``is_multimodal`` are not read here; they are kept so the
    call signature stays unchanged.

    Raises:
        AssertionError: If the measured GSM8K accuracy is below the per-model
            minimum.
    """
    (
        tp_size,
        pp_size,
        dcp_size,
        cp_kv_cache_interleave_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    multi_node_only, attn_backend = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides

    model_info.check_available_online(on_fail="skip")

    # Only the TP x PP world size is checked; dcp_size does not add to the
    # GPU requirement here (DCP presumably reuses the TP ranks).
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    # The "mp" backend is not exercised in multi-node mode.
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip(
            "Skipping multi-node context parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    server_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "4096",
        "--max-num-seqs",
        "64",
    ]
    if chunked_prefill:
        server_args.append("--enable-chunked-prefill")
    if eager_mode:
        server_args.append("--enforce-eager")
    if runner != "auto":
        server_args.extend(["--runner", runner])
    if trust_remote_code:
        server_args.append("--trust-remote-code")
    if tokenizer_mode:
        server_args.extend(["--tokenizer-mode", tokenizer_mode])
    if hf_overrides:
        server_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    server_args.extend(
        [
            "--tensor-parallel-size",
            str(tp_size),
            "--pipeline-parallel-size",
            str(pp_size),
            "--decode-context-parallel-size",
            str(dcp_size),
            "--dcp-kv-cache-interleave-size",
            str(cp_kv_cache_interleave_size),
            "--distributed-executor-backend",
            distributed_backend,
        ]
    )

    if attn_backend:
        server_args.append(f"--attention-backend={attn_backend}")

    # Generous startup timeout (720 s) for multi-GPU server launch in CI.
    with RemoteOpenAIServer(
        model_id,
        server_args,
        max_wait_seconds=720,
    ) as remote_server:
        host = f"http://{remote_server.host}"
        port = remote_server.port

        # Run GSM8K evaluation against the running server.
        results = evaluate_gsm8k(
            num_questions=NUM_QUESTIONS,
            num_shots=NUM_SHOTS,
            host=host,
            port=port,
        )

        # Validate accuracy against the per-model floor.
        accuracy = results["accuracy"]
        min_accuracy = MIN_ACCURACY[model_id]
        assert accuracy >= min_accuracy, (
            f"TP+DCP accuracy too low: {accuracy:.3f} < {min_accuracy:.3f}"
        )


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
        # Filter before expanding so excluded models are never iterated.
        if model_id in CP_TEST_MODELS
        for setting in settings
        for params in setting.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_cp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available,
):
    """Run the GSM8K context-parallel check for one parametrized setting.

    Skips DeepSeek-V2-Lite-Chat (MLA+DCP) below compute capability 9.0, and
    Qwen2.5-1.5B-Instruct (GQA+DCP) on anything other than exactly 9.0.
    """
    # Query the device capability once and reuse it for both model checks.
    capability = torch.cuda.get_device_capability()
    if model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat" and capability < (9, 0):
        pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
    if model_id == "Qwen/Qwen2.5-1.5B-Instruct" and capability != (9, 0):
        pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")

    _test_cp_gsm8k(
        model_id,
        parallel_setup,
        distributed_backend,
        runner,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )