# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers on a node other than the head node, which can cause the test
 to fail.
"""

import json
import os
from dataclasses import dataclass
from typing import Literal, NamedTuple

import pytest

from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

# Logger named after this test module.
logger = init_logger("test_sequence_parallel")

# True only when the VLLM_MULTI_NODE environment variable is exactly "1";
# unset or any other value means single-node mode.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism / compilation configuration to exercise.

    Field order matters: ``_compare_sp`` unpacks instances positionally.
    """

    tp_size: int  # value for --tensor-parallel-size
    pp_size: int  # value for --pipeline-parallel-size
    fuse_norm_quant: bool  # forwarded as pass_config["fuse_norm_quant"]
    fuse_act_quant: bool  # forwarded as pass_config["fuse_act_quant"]
    eager_mode: bool  # adds --enforce-eager when True
    chunked_prefill: bool  # adds --enable-chunked-prefill when True


class SPTestOptions(NamedTuple):
    multi_node_only: bool
43
    load_format: str | None = None
44
45
46
47
48
49


@dataclass
class SPTestSettings:
    """The set of configurations to test for one model.

    Every entry of ``parallel_setups`` is run once per entry of
    ``distributed_backends``; see ``iter_params``.
    """

    parallel_setups: list[ParallelSetup]
    distributed_backends: list[str]
    runner: RunnerOption
    test_options: SPTestOptions

    @staticmethod
    def _tp_pp_sweep(
        *,
        tp_base: int,
        pp_base: int,
        multi_node_only: bool,
        runner: RunnerOption,
        load_format: str | None,
    ) -> "SPTestSettings":
        """Sweep eager mode x PP multiplier x chunked prefill at fixed TP.

        Produces 8 setups in the order eager mode (outer), PP size
        ``{1, 2} * pp_base`` (middle), chunked prefill (inner). All
        quantization fusions are disabled. Both "mp" and "ray" backends
        are used. Shared by ``detailed`` and ``fast`` so the two cannot
        drift apart by accident.
        """
        parallel_setups = []
        for eager_mode_val in [False, True]:
            for pp_multiplier in [1, 2]:
                for chunked_prefill_val in [False, True]:
                    parallel_setups.append(
                        ParallelSetup(
                            tp_size=tp_base,
                            pp_size=pp_multiplier * pp_base,
                            fuse_norm_quant=False,
                            fuse_act_quant=False,
                            eager_mode=eager_mode_val,
                            chunked_prefill=chunked_prefill_val,
                        )
                    )
        return SPTestSettings(
            parallel_setups=parallel_setups,
            distributed_backends=["mp", "ray"],
            runner=runner,
            test_options=SPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Full eager/PP/chunked-prefill sweep at ``tp_base`` on both backends."""
        return SPTestSettings._tp_pp_sweep(
            tp_base=tp_base,
            pp_base=pp_base,
            multi_node_only=multi_node_only,
            runner=runner,
            load_format=load_format,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Same sweep as ``detailed``; kept as a separate entry point.

        NOTE(review): this currently yields exactly the ``detailed`` setups;
        narrow it here if a cheaper sweep is wanted.
        """
        return SPTestSettings._tp_pp_sweep(
            tp_base=tp_base,
            pp_base=pp_base,
            multi_node_only=multi_node_only,
            runner=runner,
            load_format=load_format,
        )

    @staticmethod
    def fp8_quant(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Two eager-mode setups at fixed TP/PP without chunked prefill.

        The first has the norm/act quant fusions off; the second has both on.
        """
        parallel_setups = []
        for fusion_val in [False, True]:
            parallel_setups.append(
                ParallelSetup(
                    tp_size=tp_base,
                    pp_size=pp_base,
                    fuse_norm_quant=fusion_val,
                    fuse_act_quant=fusion_val,
                    eager_mode=True,
                    chunked_prefill=False,
                )
            )
        return SPTestSettings(
            parallel_setups=parallel_setups,
            distributed_backends=["mp", "ray"],
            runner=runner,
            test_options=SPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend) pair.

        Each tuple is ``(model_id, parallel_setup, backend, runner,
        test_options)``, matching the parametrize names of the test.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (
                    model_id,
                    parallel_setup,
                    backend,
                    self.runner,
                    opts,
                )


def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    use_inductor_graph_partition: bool,
    fuse_gemm_comms: bool,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a TP+SP(+PP) engine against a plain-TP baseline.

    The engine under test runs with ``--tensor-parallel-size tp_size``,
    ``--pipeline-parallel-size pp_size``, the given distributed backend and
    a VLLM_COMPILE compilation config with the ``enable_sp`` pass enabled.
    The baseline uses the same common arguments with TP only, the "mp"
    backend, and no explicit compilation config. Both argument lists are
    handed to ``compare_two_settings``.

    Calls ``pytest.skip`` when:
    - there are fewer than ``tp_size * pp_size`` GPUs;
    - the "mp" backend is used in multi-node mode;
    - a multi-node-only setting runs outside multi-node mode.

    The model-info checks may also skip (they are called with
    ``on_fail="skip"``).
    """
    (
        tp_size,
        pp_size,
        fuse_norm_quant,
        fuse_act_quant,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides
    require_embed_inputs = model_info.require_embed_inputs

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        # Build a fresh dict rather than calling .update() on
        # model_info.hf_overrides. That object presumably belongs to the
        # shared model registry, so mutating it in place would leak these
        # shrunken-model overrides into later lookups of the same model.
        if is_multimodal:
            hf_overrides = {**hf_overrides, "text_config": text_overrides}
        else:
            hf_overrides = {**hf_overrides, **text_overrides}
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip(
            "Skipping multi-node pipeline parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
    if require_embed_inputs:
        common_args.extend(
            [
                "--skip-tokenizer-init",
                "--enable-prompt-embeds",
                "--enable-mm-embeds",
            ]
        )

    compilation_config = {
        "mode": CompilationMode.VLLM_COMPILE,
        "compile_sizes": [4, 8],
        "pass_config": {
            "enable_sp": True,
            "fuse_gemm_comms": fuse_gemm_comms,
            "fuse_norm_quant": fuse_norm_quant,
            "fuse_act_quant": fuse_act_quant,
            "eliminate_noops": True,
        },
        "use_inductor_graph_partition": use_inductor_graph_partition,
    }

    # Engine under test: TP (+PP) with the sequence-parallel pass enabled.
    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    # Reference engine: same common args, TP only, "mp" backend.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id, tp_sp_args, tp_args, method=method)


# Sweep definition per text-generation model.
SP_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "hmellor/tiny-random-LlamaForCausalLM": SPTestSettings.fast(),
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
}

# Only models listed here are parametrized into test_tp_sp_generation,
# even if SP_TEXT_GENERATION_MODELS defines settings for others.
SP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "hmellor/tiny-random-LlamaForCausalLM",
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
]


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
        if model_id in SP_TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
@pytest.mark.parametrize("fuse_gemm_comms", [False])  # TODO: enable async TP
@create_new_process_for_each_test()
def test_tp_sp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    use_inductor_graph_partition: bool,
    fuse_gemm_comms: bool,
):
    """Check sequence-parallel generation matches plain TP for one setting.

    Delegates to ``_compare_sp`` with ``method="generate"`` after skipping
    combinations the current environment cannot run.
    """
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

    # FP8 SP without async TP needs FP8 reductions, which are only supported
    # from compute capability 9.0 (so e.g. sm89 is skipped). The capability
    # may be None when the platform reports none; treat that as unsupported
    # rather than letting `None < (9, 0)` raise TypeError.
    if "fp8" in model_id.lower() and not fuse_gemm_comms:
        capability = current_platform.get_device_capability()
        if capability is None or capability < (9, 0):
            pytest.skip("FP8 reduction support begins with sm90 capable devices.")

    _compare_sp(
        model_id,
        parallel_setup,
        distributed_backend,
        runner,
        test_options,
        num_gpus_available,
        use_inductor_graph_partition,
        fuse_gemm_comms=fuse_gemm_comms,
        method="generate",
        is_multimodal=False,
    )