test_pipeline_parallel.py 18.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
10
import json
11
import os
12
from dataclasses import dataclass
13
from typing import Literal, NamedTuple, Optional
14

15
16
import pytest

17
from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
18
from vllm.logger import init_logger
19
from vllm.transformers_utils.config import get_config
20

21
from ..models.registry import HF_EXAMPLE_MODELS
22
from ..utils import compare_two_settings, create_new_process_for_each_test
23

24
25
# Module logger; used below to report (without failing) the known-flaky
# V0 + Ray Compiled Graph comparisons.
logger = init_logger("test_pipeline_parallel")

# Multi-node mode is enabled by exporting VLLM_MULTI_NODE=1.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"

28

29
30
31
32
33
34
35
36
37
38
39
40
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Default every test in this module to the V0 engine.

    PP used to fall back to V0 while the TP baseline ran on V1, which gives
    divergent results with dummy weights. ``_compare_tp`` now passes an
    explicit ``VLLM_USE_V1`` for the TP run and for most PP runs. This
    default only applies where no explicit value is given. Remove it once V1
    is the default for PP.
    """
    v0_flag = "0"
    monkeypatch.setenv("VLLM_USE_V1", v0_flag)


41
42
43
44
45
46
47
class ParallelSetup(NamedTuple):
    """One tensor-/pipeline-parallel layout to compare against plain TP."""

    # Tensor-parallel degree (``--tensor-parallel-size``).
    tp_size: int
    # Pipeline-parallel degree (``--pipeline-parallel-size`` for the PP run).
    pp_size: int
    # When True, ``--enforce-eager`` is passed to both engines.
    eager_mode: bool
    # When True, ``--enable-chunked-prefill`` is passed to both engines.
    chunked_prefill: bool


48
49
class PPTestOptions(NamedTuple):
    """Per-model options shared by every parametrized case of that model."""

    # Only run this model when VLLM_MULTI_NODE=1.
    multi_node_only: bool
    # Forwarded as ``--load-format``; "dummy" additionally shrinks the model
    # config to avoid OOM.
    load_format: Optional[str] = None
52


53
54
@dataclass
class PPTestSettings:
    """Describes every PP-vs-TP comparison to run for one model.

    Each entry of ``parallel_setups`` is combined with each
    ``(distributed_backend, vllm_major_version)`` pair. The two backend lists
    are zipped, not crossed, so they must have equal length.
    """

    parallel_setups: list[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    runner: RunnerOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject settings whose backend/version lists cannot be zipped 1:1.

        Raises:
            ValueError: if ``distributed_backends`` and
                ``vllm_major_versions`` differ in length.
        """
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build a thorough sweep for one model.

        The sweep crosses five TP/PP layouts with the pairs (mp, V0),
        (mp, V1), (ray, V0) and (ray, V1). Each layout scales ``tp_base`` /
        ``pp_base`` by 1 or 2 and toggles eager mode and chunked prefill.
        """
        # (tp multiplier, pp multiplier, eager_mode, chunked_prefill)
        layouts = (
            (1, 1, False, False),
            (1, 2, False, True),
            (1, 2, True, False),
            (2, 1, False, True),
            (2, 1, True, False),
        )
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_mult * tp_base,
                              pp_size=pp_mult * pp_base,
                              eager_mode=eager,
                              chunked_prefill=chunked)
                for tp_mult, pp_mult, eager, chunked in layouts
            ],
            distributed_backends=["mp", "mp", "ray", "ray"],
            vllm_major_versions=["0", "1", "0", "1"],
            runner=runner,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        runner: RunnerOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build a single cheap case: eager mode on the mp backend.

        Pooling models run on V1; every other runner runs on V0.
        """
        vllm_major_versions = ["1"] if runner == "pooling" else ["0"]

        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=vllm_major_versions,
            runner=runner,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    def iter_params(self, model_id: str):
        """Yield one pytest parameter tuple per (setup, backend/version) pair.

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, runner, test_options)``. This matches the
        parametrize signature of the ``test_tp_*`` functions in this module.
        """
        opts = self.test_options
        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_id, parallel_setup, backend, vllm_major_version,
                       self.runner, opts)
145
146


147
148
149
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# memory. The values shown here are only a rough indicator of the model size.

150
# yapf: disable
151
152
# Text-generation models for ``test_tp_language_generation``. Only entries
# whose key also appears in TEST_MODELS are actually parametrized.
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "zai-org/chatglm3-6b": PPTestSettings.fast(),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(tp_base=2),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-1.1-2b-it": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-1.4b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "pfnet/plamo-2-1b": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    # Tests TransformersForCausalLM
    "hmellor/Ilama-3.2-1B": PPTestSettings.fast(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "allenai/OLMo-2-0425-1B": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

215
216
# Pooling models for ``test_tp_language_embedding``. ``runner="pooling"``
# makes ``fast()`` select V1 and makes ``_compare_tp`` pass ``--runner``.
EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
        load_format="dummy", runner="pooling"
    ),
}

224
225
# Multimodal generation models for ``test_tp_multimodal_generation``. With
# dummy weights, their size overrides are nested under ``text_config``.
MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "zai-org/glm-4v-9b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
    "AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
    "AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

251
# NOTE: You can update this on your local machine to run specific tests
252
# Allow-list applied to all three model tables: a model is only tested when
# its id appears here.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "hmellor/Ilama-3.2-1B",
    "ibm/PowerLM-3b",
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3.5-vision-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


271
def _compare_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Run ``model_id`` with PP and with a TP-only baseline and compare them.

    Args:
        model_id: Hugging Face model id, looked up in ``HF_EXAMPLE_MODELS``.
        parallel_setup: TP/PP sizes plus the eager / chunked-prefill toggles.
        distributed_backend: executor backend for the PP run ("mp" or "ray").
            The TP baseline always uses "mp".
        vllm_major_version: "0" or "1"; exported as ``VLLM_USE_V1``.
        runner: forwarded as ``--runner`` unless it is "auto".
        test_options: per-model options (multi-node gating, load format).
        num_gpus_available: number of GPUs visible to the test.
        method: comparison method passed to ``compare_two_settings``.
        is_multimodal: with dummy weights, nest the size overrides under
            ``text_config`` instead of applying them at the top level.

    Calls ``pytest.skip`` when the environment cannot host the setup or the
    registry marks the model as unavailable. Comparison failures are
    re-raised, except for the known-flaky V0 + Ray Compiled Graph case, which
    is only logged.
    """
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    # Cheap environment checks first, so cases that will be skipped anyway
    # do not pay for fetching the model config.
    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    # Real weights are needed unless dummy ones are loaded. Check
    # availability before get_config() so an unavailable model is skipped
    # rather than erroring.
    if load_format != "dummy":
        model_info.check_available_online(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Copy: the registry entry is shared, and the dummy branch below adds keys.
    hf_overrides = dict(model_info.hf_overrides or {})
    skip_tokenizer_init = model_info.skip_tokenizer_init
    # Use the registry's limit when it sets one; otherwise default to 8.
    max_num_seqs = model_info.max_num_seqs or 8

    hf_config = get_config(model_id, trust_remote_code)

    dtype = "float16"
    if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
        dtype = "bfloat16"

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        dtype,
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        str(max_num_seqs),
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if runner != "auto":
        common_args.extend(["--runner", runner])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
    if skip_tokenizer_init:
        common_args.append("--skip-tokenizer-init")

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    testing_ray_compiled_graph = False

    if distributed_backend == "ray" and (vllm_major_version == "1"
                                         or specific_case):
        # For V1, test Ray Compiled Graph for all the tests
        # For V0, test Ray Compiled Graph for a subset of the tests
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        # NOTE: appended to common_args, so the TP baseline gets it too.
        common_args.append("--disable-frontend-multiprocessing")
        testing_ray_compiled_graph = True
    elif distributed_backend == "mp":
        # Both V0/V1 of multiprocessing executor support PP
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
        }
    else:
        # Ray + V0 outside the compiled-graph subset: no extra env. The engine
        # version presumably comes from the ambient VLLM_USE_V1, which the
        # autouse ``use_v0_only`` fixture sets to "0".
        pp_env = None

    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             pp_args,
                             tp_args,
                             pp_env,
                             tp_env,
                             method=method)
    except Exception:
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise
424
425
426


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in TEXT_GENERATION_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check that PP generation matches a TP-only baseline for text models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=False)
453
454
455


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in EMBEDDING_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_embedding(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check that PP embeddings match a TP-only baseline for pooling models."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="encode",
                is_multimodal=False)
482
483
484


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "runner", "test_options"),
    [
        # Filter on TEST_MODELS before expanding each model's settings.
        params for model_id, settings in MULTIMODAL_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_multimodal_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    runner: RunnerOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Check that PP generation matches a TP-only baseline for multimodal
    models (dummy-weight overrides go under ``text_config``)."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                runner,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=True)