test_context_parallel.py 7.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers on a node other than the head node, which can cause the test
 to fail.
"""
10

11
12
13
import json
import os
from dataclasses import dataclass
14
from typing import Literal, NamedTuple
15
16

import pytest
17
import torch
18

19
from vllm.config.model import RunnerOption
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

# Logger for this test module.
logger = init_logger("test_context_parallel")

# Opt-in flag: export VLLM_MULTI_NODE=1 when the suite runs across nodes.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism configuration exercised by the CP-vs-TP comparison.

    Field order matters: ``_compare_cp_with_tp`` unpacks instances
    positionally.
    """

    # Passed as --tensor-parallel-size to both the CP and TP runs.
    tp_size: int
    # Passed as --pipeline-parallel-size to both the CP and TP runs.
    pp_size: int
    # Passed as --decode-context-parallel-size to the CP run only.
    dcp_size: int
    # Passed as --dcp-kv-cache-interleave-size to the CP run only.
    dcp_kv_cache_interleave_size: int
    # When True, both runs are launched with --enforce-eager.
    eager_mode: bool
    # When True, both runs are launched with --enable-chunked-prefill.
    chunked_prefill: bool


class CPTestOptions(NamedTuple):
    multi_node_only: bool
41
    load_format: str | None = None
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56


@dataclass
class CPTestSettings:
    """A family of CP test cases for one model, expanded via ``iter_params``."""

    parallel_setups: list[ParallelSetup]
    distributed_backends: list[str]
    runner: RunnerOption
    test_options: CPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 4,
        pp_base: int = 1,
        dcp_base: int = 1,
        dcp_kv_cache_interleave_size: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: str | None = None,
    ):
        """Build settings that sweep the DCP size over half and all of TP.

        Args:
            tp_base: Tensor-parallel size used by every generated setup.
            pp_base: Pipeline-parallel size (scaled by each PP multiplier).
            dcp_base: Accepted for API compatibility but not read here; DCP
                sizes are derived from ``tp_base``.
            dcp_kv_cache_interleave_size: Forwarded unchanged to each setup.
            multi_node_only: Stored in the resulting ``CPTestOptions``.
            runner: Runner option forwarded to the engine (``"auto"`` means
                no ``--runner`` flag is passed).
            load_format: Optional ``--load-format`` value.

        Returns:
            A ``CPTestSettings`` using only the ``"mp"`` backend.
        """
        parallel_setups: list[ParallelSetup] = []
        for eager_mode_val in [False]:
            for pp_multiplier in [1]:
                for dcp_multiplier in [0.5, 1]:
                    for chunked_prefill_val in [True]:
                        setup = ParallelSetup(
                            tp_size=tp_base,
                            pp_size=pp_multiplier * pp_base,
                            # DCP size is a fraction of tp_base; clamp to >= 1
                            # so small tp_base values never produce size 0.
                            dcp_size=max(1, int(dcp_multiplier * tp_base)),
                            dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
                            eager_mode=eager_mode_val,
                            chunked_prefill=chunked_prefill_val,
                        )
                        # The clamp above can make multipliers collide (e.g.
                        # tp_base == 1); don't run an identical case twice.
                        if setup not in parallel_setups:
                            parallel_setups.append(setup)

        return CPTestSettings(
            parallel_setups=parallel_setups,
            distributed_backends=["mp"],
            runner=runner,
            test_options=CPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend) pair.

        Each tuple is ``(model_id, parallel_setup, backend, runner,
        test_options)``, matching the argnames of ``test_cp_generation``.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (
                    model_id,
                    parallel_setup,
                    backend,
                    self.runner,
                    opts,
                )


def _compare_cp_with_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate"],
    is_multimodal: bool,
):
    """Run ``model_id`` with decode context parallelism and with plain TP.

    Both runs share the same TP/PP sizes and common engine flags. Only the CP
    run adds ``--decode-context-parallel-size`` and
    ``--dcp-kv-cache-interleave-size``. The two configurations are handed to
    ``compare_two_settings``.

    The test is skipped via ``pytest.skip`` in these cases:

    - fewer than ``tp_size * pp_size`` GPUs are available;
    - it runs multi-node with the ``mp`` backend;
    - the case is multi-node-only outside a multi-node run.

    The model-registry checks are also invoked with ``on_fail="skip"``.
    """
    (
        tp_size,
        pp_size,
        dcp_size,
        dcp_kv_cache_interleave_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Copy so the dummy-weight overrides below never mutate the shared
    # registry entry held by HF_EXAMPLE_MODELS.
    hf_overrides = dict(model_info.hf_overrides or {})

    if load_format == "dummy":
        # Avoid OOM: shrink the model since weights are dummy anyway.
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            # Multimodal configs nest the language-model settings.
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    # Both runs use the same TP/PP sizes, so this is the GPU requirement.
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip(
            "Skipping multi-node context parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    cp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--decode-context-parallel-size",
        str(dcp_size),
        "--dcp-kv-cache-interleave-size",
        str(dcp_kv_cache_interleave_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # Baseline: identical launch without any context-parallel flags.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    compare_two_settings(
        model_id,
        cp_args,
        tp_args,
        method=method,
        max_wait_seconds=720,
    )


# Per-model CP settings; each CPTestSettings expands into one or more cases.
CP_TEXT_GENERATION_MODELS = {
    "deepseek-ai/DeepSeek-V2-Lite-Chat": [
        CPTestSettings.detailed(),
        CPTestSettings.detailed(tp_base=2),
        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
    ],
    "bigcode/gpt_bigcode-santacoder": [
        CPTestSettings.detailed(),
        CPTestSettings.detailed(tp_base=2),
    ],
}

# Models from CP_TEXT_GENERATION_MODELS that are actually parametrized.
CP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    "bigcode/gpt_bigcode-santacoder",
]


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in CP_TEXT_GENERATION_MODELS.items()
        # Filter before expanding so unlisted models generate nothing.
        if model_id in CP_TEST_MODELS
        for setting in settings
        for params in setting.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_cp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: CPTestOptions,
    num_gpus_available,
):
    """Compare CP-enabled generation against plain TP for one parametrized case.

    The test is skipped on GPUs whose compute capability the model's attention
    path does not support with DCP.
    """
    if (
        model_id == "deepseek-ai/DeepSeek-V2-Lite-Chat"
        and torch.cuda.get_device_capability() < (9, 0)
    ):
        pytest.skip(reason="MLA+DCP requires compute capability of 9.0 or higher")
    if (
        model_id == "bigcode/gpt_bigcode-santacoder"
        and torch.cuda.get_device_capability() != (9, 0)
    ):
        pytest.skip(reason="GQA+DCP currently requires compute capability of 9.0")

    _compare_cp_with_tp(
        model_id,
        parallel_setup,
        distributed_backend,
        runner,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )