# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""

import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple, Optional

import pytest

from vllm.config import RunnerOption
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

logger = init_logger("test_sequence_parallel")

# True only when the VLLM_MULTI_NODE environment variable is exactly "1",
# i.e. when this module is being run in multi-node mode.
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism/compilation configuration to test.

    Field order matters: ``_compare_sp`` unpacks instances positionally.
    """

    # Passed as --tensor-parallel-size.
    tp_size: int
    # Passed as --pipeline-parallel-size for the TP+SP run.
    pp_size: int
    # Sets pass_config["enable_fusion"] in the compilation config.
    enable_fusion: bool
    # Adds --enforce-eager when True.
    eager_mode: bool
    # Adds --enable-chunked-prefill when True.
    chunked_prefill: bool


class SPTestOptions(NamedTuple):
    """Options that apply to every case generated for one model.

    Unpacked positionally as ``(multi_node_only, load_format)``.
    """

    # When True, the case is skipped unless running in multi-node mode.
    multi_node_only: bool
    # Forwarded as --load-format when set; "dummy" additionally shrinks
    # the model config to avoid OOM.
    load_format: Optional[str] = None


@dataclass
class SPTestSettings:
    """The sequence-parallel test matrix for a single model.

    ``iter_params`` crosses every entry of ``parallel_setups`` with each
    ``(distributed_backend, vllm_major_version)`` pair to build pytest
    parameter tuples.
    """

    parallel_setups: list[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    runner: RunnerOption
    test_options: SPTestOptions

    def __post_init__(self):
        # zip() in iter_params would silently drop the tail of the longer
        # list, so reject mismatched lengths up front.
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})"
            )

    @staticmethod
    def _sweep_setups(*, tp_base: int, pp_base: int) -> list[ParallelSetup]:
        """Return the eager x PP-multiplier x chunked-prefill sweep.

        Fusion is always disabled. Ordering matches nested loops over
        eager_mode (outer), PP multiplier, then chunked_prefill (inner).
        """
        return [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_multiplier * pp_base,
                enable_fusion=False,
                eager_mode=eager_mode_val,
                chunked_prefill=chunked_prefill_val,
            )
            for eager_mode_val in [False, True]
            for pp_multiplier in [1, 2]
            for chunked_prefill_val in [False, True]
        ]

    @staticmethod
    def _from_setups(
        parallel_setups: list[ParallelSetup],
        *,
        runner: RunnerOption,
        multi_node_only: bool,
        load_format: Optional[str],
    ) -> "SPTestSettings":
        """Wrap setups with the shared backends ("mp", "ray") on V1."""
        return SPTestSettings(
            parallel_setups=parallel_setups,
            distributed_backends=["mp", "ray"],
            vllm_major_versions=["1", "1"],
            runner=runner,
            test_options=SPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: Optional[str] = None,
    ) -> "SPTestSettings":
        """Full sweep of eager/PP/chunked-prefill setups (fusion off)."""
        return SPTestSettings._from_setups(
            SPTestSettings._sweep_setups(tp_base=tp_base, pp_base=pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "SPTestSettings":
        """Sweep of eager/PP/chunked-prefill setups (fusion off).

        NOTE(review): currently produces the same matrix as ``detailed``.
        """
        return SPTestSettings._from_setups(
            SPTestSettings._sweep_setups(tp_base=tp_base, pp_base=pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fp8_quant(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "SPTestSettings":
        """Eager-mode setups toggling fusion off/on, without chunked prefill."""
        parallel_setups = [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_base,
                enable_fusion=fusion_val,
                eager_mode=True,
                chunked_prefill=False,
            )
            for fusion_val in [False, True]
        ]
        return SPTestSettings._from_setups(
            parallel_setups,
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend/version) pair.

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, runner, test_options)``.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(
                self.distributed_backends, self.vllm_major_versions
            ):
                yield (
                    model_id,
                    parallel_setup,
                    backend,
                    vllm_major_version,
                    self.runner,
                    opts,
                )


def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a TP+SP(+PP) run of ``model_id`` against a plain-TP baseline.

    The TP+SP run uses a compilation config that enables sequence
    parallelism and runs on ``distributed_backend``. The baseline uses
    the same common args with only tensor parallelism, on the "mp"
    backend. Both are handed to ``compare_two_settings`` with ``method``.

    Skips (via pytest.skip) when the transformers version or online
    availability checks fail, when fewer than ``tp_size * pp_size`` GPUs
    are available, when running multi-node with the "mp" backend, or when
    the case is multi-node-only but not running multi-node.

    Args:
        is_multimodal: when True and load_format is "dummy", the size
            overrides are nested under "text_config".
    """
    (
        tp_size,
        pp_size,
        enable_fusion,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Work on a copy: the dummy-weight overrides added below must not
    # mutate the shared HF_EXAMPLE_MODELS registry entry.
    hf_overrides = dict(model_info.hf_overrides or {})
    skip_tokenizer_init = model_info.skip_tokenizer_init

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip(
            "Skipping multi-node pipeline parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
    if skip_tokenizer_init:
        common_args.append("--skip-tokenizer-init")

    compilation_config = {
        "level": 3,
        "custom_ops": ["+rms_norm"],
        "compile_sizes": [4, 8],
        "pass_config": {
            "enable_sequence_parallelism": True,
            "enable_fusion": enable_fusion,
            "enable_noop": True,
        },
    }

    tp_sp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }
    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    # Baseline: tensor parallelism only, always on the "mp" backend.
    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(
            model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method
        )
    except Exception:
        # Ray Compiled Graph can only be in play on the "ray" backend. The
        # env dict is always set, so it cannot serve as the signal; keying
        # on the backend keeps "mp" failures from being swallowed.
        testing_ray_compiled_graph = distributed_backend == "ray"
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise


# Model id -> test matrix for text-generation sequence-parallel tests.
SP_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
}

# Allow-list of model ids from SP_TEXT_GENERATION_MODELS that are actually
# parametrized into test_tp_sp_generation.
SP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "meta-llama/Llama-3.2-1B-Instruct",
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
]


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "vllm_major_version",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
        # Filter before expanding so unlisted models generate nothing.
        if model_id in SP_TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available,
):
    """TP+SP generation output must match the plain-TP baseline."""
    _compare_sp(
        model_id,
        parallel_setup,
        distributed_backend,
        vllm_major_version,
        runner,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )