# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp"; otherwise Ray may
schedule all workers on a node other than the head node, which can cause
the test to fail.
"""
10

11
12
13
import json
import os
from dataclasses import dataclass
14
from typing import Literal, NamedTuple
15
16
17

import pytest

18
from vllm.config.compilation import CompilationMode
19
from vllm.config.model import RunnerOption
20
21
22
23
24
25
26
27
28
29
30
31
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

logger = init_logger("test_sequence_parallel")

# Multi-node mode is on only when the environment exports VLLM_MULTI_NODE=1.
# In that mode the "mp" backend cases are skipped and multi_node_only cases run.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism/compilation combination for a sequence-parallel run.

    ``_compare_sp`` unpacks these fields positionally, so their order matters.
    """

    # Forwarded as --tensor-parallel-size to both the SP run and the baseline.
    tp_size: int
    # Forwarded as --pipeline-parallel-size to the SP run only.
    pp_size: int
    # Value of pass_config["enable_fusion"] in the SP run's compilation config.
    enable_fusion: bool
    # When True, --enforce-eager is added to both runs.
    eager_mode: bool
    # When True, --enable-chunked-prefill is added to both runs.
    chunked_prefill: bool


class SPTestOptions(NamedTuple):
    multi_node_only: bool
40
    load_format: str | None = None
41
42
43
44
45
46


@dataclass
class SPTestSettings:
    """Parametrization recipe for one model in the sequence-parallel tests.

    The factories (``detailed``, ``fast``, ``fp8_quant``) choose which
    ``ParallelSetup`` combinations to sweep. ``iter_params`` crosses them with
    every distributed backend to produce pytest parameter tuples.
    """

    parallel_setups: list[ParallelSetup]
    distributed_backends: list[str]
    runner: RunnerOption
    test_options: SPTestOptions

    @staticmethod
    def _eager_pp_prefill_sweep(tp_base: int, pp_base: int) -> list[ParallelSetup]:
        """Return the eager x PP-multiplier x chunked-prefill sweep.

        Produces 8 setups with ``tp_size=tp_base`` and fusion disabled. The
        loops run, outermost first: eager mode {False, True}, pipeline
        multiplier {1, 2} applied to ``pp_base``, chunked prefill {False, True}.
        """
        return [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_multiplier * pp_base,
                enable_fusion=False,
                eager_mode=eager_mode_val,
                chunked_prefill=chunked_prefill_val,
            )
            for eager_mode_val in (False, True)
            for pp_multiplier in (1, 2)
            for chunked_prefill_val in (False, True)
        ]

    @staticmethod
    def _build(
        parallel_setups: list[ParallelSetup],
        *,
        runner: RunnerOption,
        multi_node_only: bool,
        load_format: str | None,
    ) -> "SPTestSettings":
        """Wrap ``parallel_setups`` with the backends/options every factory uses."""
        return SPTestSettings(
            parallel_setups=parallel_setups,
            # Fresh list per instance so settings objects never share state.
            distributed_backends=["mp", "ray"],
            runner=runner,
            test_options=SPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Full sweep: eager/compiled x PP {1, 2} x chunked prefill, mp and ray."""
        return SPTestSettings._build(
            SPTestSettings._eager_pp_prefill_sweep(tp_base, pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Sweep used for the default CI models.

        NOTE(review): this currently yields exactly the same matrix as
        ``detailed``; narrow it here if a smaller sweep is intended.
        """
        return SPTestSettings._build(
            SPTestSettings._eager_pp_prefill_sweep(tp_base, pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fp8_quant(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ) -> "SPTestSettings":
        """Eager-only, no chunked prefill; toggles only ``enable_fusion``."""
        parallel_setups = [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_base,
                enable_fusion=fusion_val,
                eager_mode=True,
                chunked_prefill=False,
            )
            for fusion_val in (False, True)
        ]
        return SPTestSettings._build(
            parallel_setups,
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    def iter_params(self, model_id: str):
        """Yield ``(model_id, parallel_setup, backend, runner, test_options)``.

        Setups are the outer loop and backends the inner loop, so every setup
        is emitted once per distributed backend.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (model_id, parallel_setup, backend, self.runner, opts)


def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Run ``model_id`` with sequence parallelism and compare it to a baseline.

    The sequence-parallel (SP) run has these properties:

    - It uses the TP/PP sizes from ``parallel_setup`` on
      ``distributed_backend``.
    - Its compilation config has ``CompilationMode.VLLM_COMPILE`` with
      ``enable_sequence_parallelism`` turned on.

    The baseline differs from the SP run as follows:

    - It keeps the same TP size and uses the ``mp`` backend.
    - It passes no ``--pipeline-parallel-size`` and no compilation config.

    Both argument lists go to ``compare_two_settings``.

    The case is skipped via ``pytest.skip`` in these situations:

    - The model's transformers or online-availability checks fail. The online
      check only runs when not using dummy weights.
    - Fewer than ``tp_size * pp_size`` GPUs are available.
    - The backend is ``mp`` in multi-node mode.
    - ``multi_node_only`` is set outside multi-node mode.

    Args:
        model_id: Key into ``HF_EXAMPLE_MODELS``.
        parallel_setup: TP/PP sizes, fusion, eager and chunked-prefill flags.
        distributed_backend: Executor backend for the SP run ("mp" or "ray").
        runner: Forwarded as ``--runner`` unless it is "auto".
        test_options: Multi-node gating and optional load format.
        num_gpus_available: GPUs visible to the test.
        method: Passed through to ``compare_two_settings``.
        is_multimodal: With dummy weights, nest the size overrides under
            ``text_config`` instead of applying them at the top level.
    """
    tp_size, pp_size, enable_fusion, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Copy rather than alias. ``model_info`` comes from the module-level
    # HF_EXAMPLE_MODELS registry, and the dummy-weight branch below mutates
    # this dict; updating the registry's own object would leak the overrides.
    hf_overrides = dict(model_info.hf_overrides or {})
    skip_tokenizer_init = model_info.skip_tokenizer_init

    if load_format == "dummy":
        # Avoid OOM: shrink the model when loading dummy weights.
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip(
            "Skipping multi-node sequence parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    # Flags shared by the SP run and the baseline.
    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
    if skip_tokenizer_init:
        common_args.append("--skip-tokenizer-init")

    compilation_config = {
        "mode": CompilationMode.VLLM_COMPILE,
        "custom_ops": ["+rms_norm"],
        "compile_sizes": [4, 8],
        "pass_config": {
            "enable_sequence_parallelism": True,
            "enable_fusion": enable_fusion,
            "enable_noop": True,
        },
    }

    # Run under test: TP (+ optional PP) on the requested backend with the
    # sequence-parallelism compilation pass enabled.
    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation-config",
        json.dumps(compilation_config),
    ]

    # Reference run: same TP size on the mp backend, without the PP flag or
    # the SP compilation config.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id, tp_sp_args, tp_args, method=method)


# Text-generation models covered by this file, mapped to the settings factory
# that defines each model's parametrization sweep.
SP_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
}

# Models from SP_TEXT_GENERATION_MODELS that are actually parametrized;
# any model not listed here is filtered out of test_tp_sp_generation.
SP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "meta-llama/Llama-3.2-1B-Instruct",
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
]


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
        # Filter first so excluded models never have their params expanded.
        if model_id in SP_TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
):
    """Check that generation with sequence parallelism matches the TP baseline."""
    _compare_sp(
        model_id,
        parallel_setup,
        distributed_backend,
        runner,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )