# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 nodes with 2 GPUs each) modes. If the test uses only 2 GPUs, set the
distributed backend to "mp"; otherwise Ray may schedule all workers on a
node other than the head node, which can cause the test to fail.
"""

11
12
13
import json
import os
from dataclasses import dataclass
14
from typing import Literal, NamedTuple
15
16
17

import pytest

18
from vllm.config.model import RunnerOption
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.logger import init_logger

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

# Logger for this test module.
logger = init_logger("test_sequence_parallel")

# True only when the environment sets VLLM_MULTI_NODE to exactly "1".
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism/compilation configuration to exercise.

    ``_compare_sp`` unpacks this tuple positionally, so field order is part
    of the contract.
    """

    # Value passed to --tensor-parallel-size.
    tp_size: int
    # Value passed to --pipeline-parallel-size for the SP run.
    pp_size: int
    # Forwarded as pass_config["enable_fusion"] in the compilation config.
    enable_fusion: bool
    # When True, --enforce-eager is added to the shared CLI args.
    eager_mode: bool
    # When True, --enable-chunked-prefill is added to the shared CLI args.
    chunked_prefill: bool


class SPTestOptions(NamedTuple):
    multi_node_only: bool
39
    load_format: str | None = None
40
41
42
43
44
45


@dataclass
class SPTestSettings:
    """Matrix of parallel setups and executor backends for one model.

    Instances are built through the ``detailed`` / ``fast`` / ``fp8_quant``
    factories and expanded into pytest parameters by ``iter_params``.
    """

    parallel_setups: list[ParallelSetup]
    distributed_backends: list[str]
    runner: RunnerOption
    test_options: SPTestOptions

    @staticmethod
    def _grid_setups(tp_base: int, pp_base: int) -> list[ParallelSetup]:
        """Return the 8-setup grid used by ``detailed`` and ``fast``.

        Crosses eager_mode {False, True} x pp multiplier {1, 2} x
        chunked_prefill {False, True}, with fusion disabled. Iteration order
        is eager_mode outermost and chunked_prefill innermost.
        """
        return [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_multiplier * pp_base,
                enable_fusion=False,
                eager_mode=eager_mode_val,
                chunked_prefill=chunked_prefill_val,
            )
            for eager_mode_val in (False, True)
            for pp_multiplier in (1, 2)
            for chunked_prefill_val in (False, True)
        ]

    @staticmethod
    def _build(
        parallel_setups: list[ParallelSetup],
        *,
        runner: RunnerOption,
        multi_node_only: bool,
        load_format: str | None,
    ) -> "SPTestSettings":
        """Wrap ``parallel_setups`` with the backends and options all factories use."""
        return SPTestSettings(
            parallel_setups=parallel_setups,
            # A fresh list per instance, so settings never share mutable state.
            distributed_backends=["mp", "ray"],
            runner=runner,
            test_options=SPTestOptions(
                multi_node_only=multi_node_only, load_format=load_format
            ),
        )

    @staticmethod
    def detailed(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: str | None = None,
    ):
        """Build settings covering the full 8-setup grid on mp and ray."""
        return SPTestSettings._build(
            SPTestSettings._grid_setups(tp_base, pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ):
        """Build settings for the "fast" matrix on mp and ray.

        NOTE(review): this currently yields exactly the same grid as
        ``detailed``. If it is meant to be a reduced matrix, trim it here.
        """
        return SPTestSettings._build(
            SPTestSettings._grid_setups(tp_base, pp_base),
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    @staticmethod
    def fp8_quant(
        *,
        tp_base: int = 2,
        pp_base: int = 1,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: str | None = None,
    ):
        """Build settings that toggle fusion, with eager mode and no chunked prefill."""
        parallel_setups = [
            ParallelSetup(
                tp_size=tp_base,
                pp_size=pp_base,
                enable_fusion=fusion_val,
                eager_mode=True,
                chunked_prefill=False,
            )
            for fusion_val in (False, True)
        ]
        return SPTestSettings._build(
            parallel_setups,
            runner=runner,
            multi_node_only=multi_node_only,
            load_format=load_format,
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend) pair.

        Tuple order matches the ``parametrize`` argnames: model_id,
        parallel_setup, distributed_backend, runner, test_options.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend in self.distributed_backends:
                yield (model_id, parallel_setup, backend, self.runner, opts)


def _compare_sp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a TP + sequence-parallel (+PP) run against a plain-TP baseline.

    Builds two vLLM CLI argument lists that share ``common_args`` and passes
    both to ``compare_two_settings``:

    * ``tp_sp_args``: TP/PP sizes from ``parallel_setup``, the requested
      executor backend, and a compilation config with
      ``enable_sequence_parallelism`` turned on.
    * ``tp_args``: the same TP size on the ``mp`` backend, with no PP flag
      and no custom compilation config.

    Calls ``pytest.skip`` when fewer than ``tp_size * pp_size`` GPUs are
    available, when running multi-node with the ``mp`` backend, or when
    ``multi_node_only`` is set outside a multi-node run. The model-info
    version and availability checks are also called with ``on_fail="skip"``.

    Args:
        model_id: Key looked up in ``HF_EXAMPLE_MODELS``.
        parallel_setup: Parallel/compilation configuration for the SP run.
        distributed_backend: Executor backend for the SP run.
        runner: Passed as ``--runner`` unless it is ``"auto"``.
        test_options: ``(multi_node_only, load_format)`` for this model.
        num_gpus_available: GPU count used for the skip check.
        method: Forwarded to ``compare_two_settings``.
        is_multimodal: With ``load_format == "dummy"``, places the shrunken
            text overrides under ``text_config`` instead of the top level.
    """
    (
        tp_size,
        pp_size,
        enable_fusion,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Copy so the dummy-load overrides below never mutate the shared
    # registry entry held by HF_EXAMPLE_MODELS.
    hf_overrides = dict(model_info.hf_overrides or {})
    skip_tokenizer_init = model_info.skip_tokenizer_init

    if load_format == "dummy":
        # Avoid OOM: shrink the model when running with dummy weights.
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        # Only non-dummy runs need the real checkpoint to be reachable.
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        # NOTE(review): presumably the mp backend cannot place workers
        # across nodes, so only ray is exercised in multi-node runs.
        pytest.skip(
            "Skipping multi-node sequence parallel test for "
            "multiprocessing distributed backend"
        )
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
    if skip_tokenizer_init:
        common_args.append("--skip-tokenizer-init")

    compilation_config = {
        "level": 3,
        "custom_ops": ["+rms_norm"],
        "compile_sizes": [4, 8],
        "pass_config": {
            "enable_sequence_parallelism": True,
            "enable_fusion": enable_fusion,
            "enable_noop": True,
        },
    }

    tp_sp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--pipeline-parallel-size",
        str(pp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    # Baseline: plain TP with the same shared args and no SP compilation.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id, tp_sp_args, tp_args, method=method)


# Per-model test matrices for text generation. A model is collected by
# test_tp_sp_generation only if it is also listed in SP_TEST_MODELS.
SP_TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(),
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8": SPTestSettings.fp8_quant(),
}

# Allow-list of model IDs that test_tp_sp_generation actually runs.
SP_TEST_MODELS = [
    # TODO support other models
    # [LANGUAGE GENERATION]
    "meta-llama/Llama-3.2-1B-Instruct",
    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
]


@pytest.mark.parametrize(
    (
        "model_id",
        "parallel_setup",
        "distributed_backend",
        "runner",
        "test_options",
    ),
    [
        params
        for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
        # Filter first so params are only generated for selected models.
        if model_id in SP_TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_sp_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    runner: RunnerOption,
    test_options: SPTestOptions,
    num_gpus_available,
):
    """Run the TP+SP vs. TP comparison in generate mode for one text model.

    Parameters come from ``SPTestSettings.iter_params`` for each model that
    appears in both ``SP_TEXT_GENERATION_MODELS`` and ``SP_TEST_MODELS``.
    ``num_gpus_available`` is presumably a conftest fixture (TODO confirm).
    """
    _compare_sp(
        model_id,
        parallel_setup,
        distributed_backend,
        runner,
        test_options,
        num_gpus_available,
        method="generate",
        is_multimodal=False,
    )