# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
9
import json
10
import os
11
from dataclasses import dataclass
12
from typing import Literal, NamedTuple, Optional
13

14
15
import pytest

16
from vllm.config import TaskOption
17
18
from vllm.logger import init_logger

19
from ..models.registry import HF_EXAMPLE_MODELS
20
from ..utils import compare_two_settings, fork_new_process_for_each_test
21

22
23
# Module-level logger; used to record tolerated Ray Compiled Graph failures.
logger = init_logger("test_pipeline_parallel")

# True only when the environment variable VLLM_MULTI_NODE is exactly "1".
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

26

27
28
29
30
31
32
33
34
35
36
37
38
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Force ``VLLM_USE_V1=0`` for every test function in this module.

    PP currently falls back to V0 by default, so without this the TP
    baseline would run on V1 while the PP engine runs on V0, and the two
    give divergent results with dummy weights. This fixture can be
    removed once V1 is enabled by default for PP.
    """
    monkeypatch.setenv("VLLM_USE_V1", "0")


39
40
41
42
43
44
45
class ParallelSetup(NamedTuple):
    """One parallelism configuration exercised by the pipeline-parallel tests.

    The field order matters: ``_compare_tp`` unpacks instances positionally.
    """
    # Value passed to ``--tensor-parallel-size``.
    tp_size: int
    # Value passed to ``--pipeline-parallel-size``.
    pp_size: int
    # When true, ``--enforce-eager`` is added to the server arguments.
    eager_mode: bool
    # When true, ``--enable-chunked-prefill`` is added to the server arguments.
    chunked_prefill: bool


46
47
class PPTestOptions(NamedTuple):
    """Per-model options that control skipping and weight loading.

    The field order matters: ``_compare_tp`` unpacks instances positionally.
    """
    # When true, the test is skipped unless VLLM_MULTI_NODE is enabled.
    multi_node_only: bool
    # Forwarded as ``--load-format`` when set; "dummy" also shrinks the model.
    load_format: Optional[str] = None
49
50


51
52
@dataclass
class PPTestSettings:
    """Test matrix for one model: parallel setups x (backend, engine version).

    Use the ``detailed`` or ``fast`` factories to build the common presets,
    then ``iter_params`` to expand the matrix into pytest parameters.
    """
    parallel_setups: list[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    task: TaskOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject settings whose backend and version lists cannot be zipped.

        Raises:
            ValueError: if ``distributed_backends`` and
                ``vllm_major_versions`` differ in length.
        """
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                "Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the thorough preset.

        Produces five parallel setups (base and doubled TP/PP sizes, with
        eager mode and chunked prefill toggled) combined with three
        backend/version pairs: mp + V0, ray + V0 and ray + V1.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            # only ray is supported for V1
            distributed_backends=["mp", "ray", "ray"],
            vllm_major_versions=["0", "0", "1"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the minimal preset: one eager-mode setup on mp + V0."""
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=["0"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    def iter_params(self, model_id: str):
        """Yield one parameter tuple per (parallel setup, backend/version).

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, task, test_options)``, matching the argument
        names used by the parametrized tests in this module.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            # Backends and versions are paired element-wise, not crossed.
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_id, parallel_setup, backend, vllm_major_version,
                       self.task, opts)
142
143


144
145
146
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-2b": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-12b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

209
210
EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
}

216
217
MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "THUDM/glm-4v-9b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

241
# NOTE: You can update this on your local machine to run specific tests
# Only model IDs listed here are expanded into test cases; the per-model
# settings come from the *_MODELS tables.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3.5-vision-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


259
def _compare_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a pipeline-parallel run against a tensor-parallel-only run.

    Builds two server argument lists for ``model_id`` (one with pipeline
    parallelism on ``distributed_backend``, one TP-only on the "mp"
    backend) and hands them to ``compare_two_settings``.

    Args:
        model_id: Model to look up in ``HF_EXAMPLE_MODELS`` and to serve.
        parallel_setup: TP/PP sizes plus eager-mode and chunked-prefill flags.
        distributed_backend: Executor backend used for the PP run.
        vllm_major_version: "0" or "1"; exported as ``VLLM_USE_V1``.
        task: Forwarded as ``--task`` unless it is "auto".
        test_options: Multi-node restriction and optional load format.
        num_gpus_available: GPUs on this machine; the test is skipped when
            fewer than ``tp_size * pp_size`` are available.
        method: Passed through to ``compare_two_settings``.
        is_multimodal: When using dummy weights, nest the size overrides
            under ``text_config`` instead of applying them at the top level.

    Raises:
        Exception: whatever ``compare_two_settings`` raises, except for
            V0 Ray Compiled Graph runs, whose failures are only logged.
    """
    (
        tp_size,
        pp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Work on a copy so the overrides added below do not leak into the
    # shared registry entry.
    hf_overrides = dict(model_info.hf_overrides or {})

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    if distributed_backend == "ray" and (vllm_major_version == "1"
                                         or specific_case):
        # For V1, test Ray Compiled Graph for all the tests
        # For V0, test Ray Compiled Graph for a subset of the tests
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        # This must stay before pp_args/tp_args are built so both runs get it.
        common_args.append("--disable-frontend-multiprocessing")
    else:
        pp_env = None

    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             pp_args,
                             tp_args,
                             pp_env,
                             tp_env,
                             method=method)
    except Exception:
        # pp_env is only set on the Ray Compiled Graph path above.
        testing_ray_compiled_graph = pp_env is not None
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise
395
396
397


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on the model first so unselected models are never expanded.
        params for model_id, settings in TEXT_GENERATION_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check PP against TP-only for the selected text-generation models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=False)
424
425
426


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on the model first so unselected models are never expanded.
        params for model_id, settings in EMBEDDING_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check PP against TP-only for the selected embedding models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="encode",
                is_multimodal=False)
453
454
455


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on the model first so unselected models are never expanded.
        params for model_id, settings in MULTIMODAL_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check PP against TP-only for the selected multimodal models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=True)