test_pipeline_parallel.py 17.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
10
import json
11
import os
12
from dataclasses import dataclass
13
from typing import Literal, NamedTuple, Optional
14

15
16
import pytest

17
from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
18
from vllm.logger import init_logger
19
from vllm.transformers_utils.config import get_config
20

21
from ..models.registry import HF_EXAMPLE_MODELS
22
from ..utils import compare_two_settings, create_new_process_for_each_test
23

24
25
logger = init_logger("test_pipeline_parallel")

# Multi-node mode is opted into by exporting VLLM_MULTI_NODE=1; an unset
# variable or any other value means single-node mode.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Pin every test in this module to the V0 engine.

    PP falls back to V0 by default, so without this override the TP baseline
    would run on V1 while the PP engine runs on V0, which gives divergent
    results with dummy weights. Remove this fixture once PP defaults to V1.
    """
    env_var, engine_version = "VLLM_USE_V1", "0"
    monkeypatch.setenv(env_var, engine_version)


class ParallelSetup(NamedTuple):
    """A tensor/pipeline-parallel layout plus engine mode flags."""

    # Tensor-parallel degree (passed as --tensor-parallel-size).
    tp_size: int
    # Pipeline-parallel degree (passed as --pipeline-parallel-size).
    pp_size: int
    # When True, the engine is launched with --enforce-eager.
    eager_mode: bool
    # When True, the engine is launched with --enable-chunked-prefill.
    chunked_prefill: bool


class PPTestOptions(NamedTuple):
    """Per-model options shared by every generated PP test case."""

    # When True, the case is skipped unless VLLM_MULTI_NODE=1.
    multi_node_only: bool
    # Optional --load-format value (e.g. "dummy"); None means no
    # --load-format flag is passed to the engine.
    load_format: Optional[str] = None


@dataclass
class PPTestSettings:
    """A family of pipeline-parallel test cases for a single model.

    Each entry of ``parallel_setups`` is combined with every
    (distributed backend, vLLM major version) pair; see ``iter_params``.
    """

    parallel_setups: list[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    runner: RunnerOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject mismatched backend/version lists (zip would truncate)."""
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Broad coverage: five layouts x {mp, ray} x {V0, V1}.

        The layouts are the base (tp_base, pp_base) layout, two layouts with
        a doubled PP degree and two with a doubled TP degree, each toggling
        eager mode / chunked prefill.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "mp", "ray", "ray"],
            vllm_major_versions=["0", "1", "0", "1"],
            runner=runner,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Smoke coverage: one eager-mode layout on the mp backend with V0."""
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=["0"],
            runner=runner,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend, version).

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, runner, test_options)``, matching the argnames
        used by the parametrized test functions in this module.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_id, parallel_setup, backend, vllm_major_version,
                       self.runner, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

# yapf: disable
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "zai-org/chatglm3-6b": PPTestSettings.fast(),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(tp_base=2),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-1.1-2b-it": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-1.4b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "pfnet/plamo-2-1b": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    # Tests TransformersForCausalLM
    "hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "allenai/OLMo-2-0425-1B": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
        load_format="dummy", runner="pooling"
    ),
}

MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "zai-org/glm-4v-9b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

# NOTE: You can update this on your local machine to run specific tests
# Only models listed here are turned into parametrized test cases.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "hmellor/Ilama-3.2-1B",
    "ibm/PowerLM-3b",
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3.5-vision-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


def _compare_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Check that a PP (+TP) engine matches a TP-only baseline for one model.

    The model is launched twice via ``compare_two_settings``: once with the
    requested pipeline-parallel layout and distributed backend, and once
    with tensor parallelism only on the ``mp`` backend. That helper then
    compares the two runs using ``method``.

    The case is skipped in any of these situations:
    - the installed transformers version is unsupported;
    - too few GPUs are available;
    - a multi-node run uses the ``mp`` backend;
    - a multi-node-only model runs outside a multi-node setting;
    - a real-weight (non-dummy) model is not available online.

    Raises:
        Whatever ``compare_two_settings`` raises, except for V0 Ray Compiled
        Graph runs, whose failures are only logged.
    """
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    # Purely local skip conditions run first, so that cases which cannot
    # run here never trigger the config download below.
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")
    # Check availability before get_config(), so an unavailable model is
    # skipped instead of failing while its config is fetched.
    if load_format != "dummy":
        model_info.check_available_online(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Copy so the dummy-weight overrides below never mutate the shared
    # HF_EXAMPLE_MODELS registry entry.
    hf_overrides = dict(model_info.hf_overrides or {})

    hf_config = get_config(model_id, trust_remote_code)

    dtype = "float16"
    if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
        dtype = "bfloat16"

    if load_format == "dummy":
        # Avoid OOM by shrinking the text model
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        dtype,
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    # Pin the engine version for the PP run explicitly (for every backend)
    # instead of silently inheriting it from the process environment.
    pp_env = {"VLLM_USE_V1": vllm_major_version}

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    testing_ray_compiled_graph = (distributed_backend == "ray"
                                  and (vllm_major_version == "1"
                                       or specific_case))
    if testing_ray_compiled_graph:
        # For V1, test Ray Compiled Graph for all the tests
        # For V0, test Ray Compiled Graph for a subset of the tests
        pp_env.update({
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        })
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        # Appended before pp_args/tp_args are built, so both runs get it.
        common_args.append("--disable-frontend-multiprocessing")

    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             pp_args,
                             tp_args,
                             pp_env,
                             tp_env,
                             method=method)
    except Exception:
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in TEXT_GENERATION_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP against a TP-only baseline for text-generation models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=False)


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in EMBEDDING_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_embedding(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP against a TP-only baseline for text embedding models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="encode",
                is_multimodal=False)


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in MULTIMODAL_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_multimodal_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP against a TP-only baseline for multimodal generation."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=True)