# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
9
import json
10
import os
11
from dataclasses import dataclass
12
from typing import Literal, NamedTuple, Optional
13

14
15
import pytest

16
from vllm.config import TaskOption
17
18
from vllm.logger import init_logger

19
from ..models.registry import HF_EXAMPLE_MODELS
20
from ..utils import compare_two_settings, create_new_process_for_each_test
21

22
23
logger = init_logger("test_pipeline_parallel")

# Multi-node mode is enabled only when VLLM_MULTI_NODE is exactly "1";
# an unset variable or any other value means a single-node run.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Set ``VLLM_USE_V1=0`` as the baseline environment for every test here.

    PP currently falls back to V0 by default, so without this the TP
    baseline would run on V1 while the PP engine runs on V0, which gives
    divergent results with dummy weights. Remove this fixture once PP
    enables V1 by default.
    """
    env_var, value = "VLLM_USE_V1", "0"
    monkeypatch.setenv(env_var, value)


class ParallelSetup(NamedTuple):
    """One TP x PP layout plus engine switches for a single test case."""

    # tensor-parallel size (--tensor-parallel-size)
    tp_size: int
    # pipeline-parallel size (--pipeline-parallel-size)
    pp_size: int
    # when True, the engine is started with --enforce-eager
    eager_mode: bool
    # when True, the engine is started with --enable-chunked-prefill
    chunked_prefill: bool


class PPTestOptions(NamedTuple):
    """Per-model options shared by every parallel setup of that model."""

    # skip the model unless running in multi-node mode (VLLM_MULTI_NODE=1)
    multi_node_only: bool
    # forwarded as --load-format; "dummy" also shrinks the model config
    load_format: Optional[str] = None


@dataclass
class PPTestSettings:
    """Parallel setups and backend/version combinations to test a model with.

    ``iter_params`` expands the settings into pytest parameter tuples: every
    entry of ``parallel_setups`` is combined with every
    ``(distributed_backend, vllm_major_version)`` pair.
    """

    parallel_setups: list[ParallelSetup]

    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    task: TaskOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject backend/version lists of different lengths.

        They are zipped in ``iter_params``, so a mismatch would silently
        drop combinations.

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the exhaustive settings.

        The five TP/PP setups are derived from ``tp_base`` and ``pp_base``
        (doubling one or the other, with eager and chunked-prefill
        variants). Each setup is run on the mp and ray backends, for both
        V0 and V1.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "mp", "ray", "ray"],
            vllm_major_versions=["0", "1", "0", "1"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the quick smoke-test settings.

        Runs a single eager-mode setup of ``tp_base`` x ``pp_base`` on the
        mp backend with V0 only.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=["0"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    def iter_params(self, model_id: str):
        """Yield pytest parameter tuples for ``model_id``.

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, task, test_options)``, in the order expected by
        the ``test_tp_*`` functions' parametrize argnames.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_id, parallel_setup, backend, vllm_major_version,
                       self.task, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# memory. The values displayed here are only a rough indicator of the size of
# the model.

# yapf: disable
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-1.1-2b-it": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-1.4b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    # Tests TransformersForCausalLM
    "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

210
211
EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
212
213
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
214
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
215
216
}

217
218
MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "THUDM/glm-4v-9b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

# NOTE: You can update this on your local machine to run specific tests.
# Only model IDs listed here are turned into test cases; each must also be a
# key of TEXT_GENERATION_MODELS, EMBEDDING_MODELS or MULTIMODAL_MODELS.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "ArthurZ/Ilama-3.2-1B",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3.5-vision-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


def _compare_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a pipeline-parallel run of ``model_id`` with a TP-only run.

    The pipeline-parallel (PP) engine uses ``distributed_backend``. The
    tensor-parallel-only baseline always uses the ``mp`` backend. Both are
    launched through ``compare_two_settings``.

    Args:
        model_id: Hugging Face model ID looked up in ``HF_EXAMPLE_MODELS``.
        parallel_setup: TP/PP sizes plus eager / chunked-prefill switches.
        distributed_backend: ``"mp"`` or ``"ray"``, used for the PP engine.
        vllm_major_version: ``"0"`` or ``"1"``; exported as ``VLLM_USE_V1``.
        task: Passed as ``--task`` unless it is ``"auto"``.
        test_options: Multi-node restriction and optional ``--load-format``.
        num_gpus_available: Number of GPUs visible to the test.
        method: Which ``compare_two_settings`` method to exercise.
        is_multimodal: Whether the dummy-weight size overrides go under
            ``text_config`` instead of the top-level HF config.

    The test is skipped via ``pytest.skip`` (or the registry's
    ``on_fail="skip"`` checks) in these cases:

    * the model cannot be used;
    * too few GPUs are available;
    * the multi-node setting does not fit this case.

    Comparison failures are re-raised. The exception is V0 runs that use
    Ray Compiled Graph: those are flaky and are only logged.
    """
    tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides

    if load_format == "dummy":
        # Shrink the model config to avoid OOM when running with dummy weights
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        # Build a new dict instead of calling .update(): hf_overrides comes
        # from the model registry (presumably a shared object), and mutating
        # it in place would leak these overrides into later lookups.
        if is_multimodal:
            hf_overrides = {**hf_overrides, "text_config": text_overrides}
        else:
            hf_overrides = {**hf_overrides, **text_overrides}
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    # For V1, test Ray Compiled Graph for all the Ray tests.
    # For V0, test Ray Compiled Graph only for one representative setup.
    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    testing_ray_compiled_graph = (distributed_backend == "ray"
                                  and (vllm_major_version == "1"
                                       or specific_case))
    if testing_ray_compiled_graph:
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        common_args.append("--disable-frontend-multiprocessing")
    elif distributed_backend == "mp":
        # Both V0/V1 of multiprocessing executor support PP
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
        }
    else:
        pp_env = None

    tp_env = {
        "VLLM_USE_V1": vllm_major_version,
    }

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id,
                             pp_args,
                             tp_args,
                             pp_env,
                             tp_env,
                             method=method)
    except Exception:
        # Use the explicit flag rather than `pp_env is not None`: the mp
        # backend also sets pp_env. Inferring from pp_env would swallow every
        # V0 multiprocessing failure as if it were a flaky Ray run.
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
            logger.exception("Ray Compiled Graph tests failed")
        else:
            raise


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        params for model_id, settings in TEXT_GENERATION_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only generation for text models in TEST_MODELS."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=False)


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        params for model_id, settings in EMBEDDING_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_language_embedding(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only encoding for embedding models in TEST_MODELS."""
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="encode",
                is_multimodal=False)


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        params for model_id, settings in MULTIMODAL_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@create_new_process_for_each_test()
def test_tp_multimodal_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only generation for multimodal models in TEST_MODELS.

    With dummy weights, the size overrides are applied to ``text_config``.
    """
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=True)