# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
9
import json
10
import os
11
from dataclasses import dataclass
12
from typing import Literal, NamedTuple, Optional
13

14
15
import pytest

16
from vllm.config import TaskOption
17
18
from vllm.logger import init_logger

19
from ..models.registry import HF_EXAMPLE_MODELS
20
from ..utils import compare_two_settings, fork_new_process_for_each_test
21

22
23
logger = init_logger("test_pipeline_parallel")

# True only when the VLLM_MULTI_NODE environment variable is exactly "1";
# unset or any other value means single-node mode.
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

26

27
28
29
30
31
32
33
class ParallelSetup(NamedTuple):
    """One parallel layout to test: TP/PP sizes plus two execution flags."""
    # Forwarded as the ``--tensor-parallel-size`` argument.
    tp_size: int
    # Forwarded as the ``--pipeline-parallel-size`` argument.
    pp_size: int
    # When True, ``--enforce-eager`` is added to the arguments.
    eager_mode: bool
    # When True, ``--enable-chunked-prefill`` is added to the arguments.
    chunked_prefill: bool


34
35
class PPTestOptions(NamedTuple):
    """Per-model options shared by every parallel setup of that model."""
    # When True, the test is skipped unless VLLM_MULTI_NODE is enabled.
    multi_node_only: bool
    # Forwarded as ``--load-format`` when set; ``None`` leaves the flag out.
    load_format: Optional[str] = None
37
38


39
40
@dataclass
class PPTestSettings:
    """Test matrix for a single model.

    Every entry of ``parallel_setups`` is combined with every
    ``(distributed_backend, vllm_major_version)`` pair, where the pairs
    are formed by zipping the two lists position by position.
    """
    parallel_setups: list[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: list[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: list[str]
    task: TaskOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject settings whose backend and version lists differ in length.

        Raises:
            ValueError: if the two lists cannot be zipped one-to-one.
        """
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the broad matrix: five parallel setups x three backends.

        Args:
            tp_base: Base tensor-parallel size; some setups use twice this.
            pp_base: Base pipeline-parallel size; some setups use twice this.
            multi_node_only: Stored in the resulting ``PPTestOptions``.
            task: Task option stored on the settings.
            load_format: Stored in the resulting ``PPTestOptions``.

        Returns:
            Settings covering eager and non-eager execution, with and
            without chunked prefill, at the base and doubled TP/PP sizes.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            # only ray is supported for V1
            distributed_backends=["mp", "ray", "ray"],
            vllm_major_versions=["0", "0", "1"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        load_format: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the minimal matrix: one eager setup on the "mp" backend, V0.

        Args:
            tp_base: Tensor-parallel size of the single setup.
            pp_base: Pipeline-parallel size of the single setup.
            task: Task option stored on the settings.
            multi_node_only: Stored in the resulting ``PPTestOptions``.
            load_format: Stored in the resulting ``PPTestOptions``.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=["0"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       load_format=load_format),
        )

    def iter_params(self, model_id: str):
        """Yield one parametrize tuple per (setup, backend/version) pair.

        Each tuple is ``(model_id, parallel_setup, backend,
        vllm_major_version, task, test_options)``; setups vary slowest.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_id, parallel_setup, backend, vllm_major_version,
                       self.task, opts)
130
131


132
133
134
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

135
# yapf: disable
136
137
# Text-generation model IDs mapped to the settings they are tested with.
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(),
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(load_format="dummy"),
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-2b": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-12b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(),
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"),  # noqa: E501
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

197
198
# Embedding model IDs mapped to the settings they are tested with.
EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
}

204
205
# Multimodal model IDs mapped to the settings they are tested with.
MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "THUDM/glm-4v-9b": PPTestSettings.fast(),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
    "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

229
# NOTE: You can update this on your local machine to run specific tests
# Only model IDs listed here are expanded into parametrized test cases.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3.5-vision-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


247
def _compare_tp(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
    is_multimodal: bool,
):
    """Compare a pipeline-parallel run of a model against a TP-only run.

    Builds two CLI argument lists (one with pipeline parallelism on the
    requested backend, one with tensor parallelism only on the "mp"
    backend) and hands both to ``compare_two_settings``.

    Args:
        model_id: Model identifier, also used to look up registry info.
        parallel_setup: TP/PP sizes and the eager / chunked-prefill flags.
        distributed_backend: Executor backend used for the PP run.
        vllm_major_version: "0" or "1"; exported as ``VLLM_USE_V1`` when the
            Ray Compiled Graph environment is used.
        task: Passed as ``--task`` unless it is "auto".
        test_options: Multi-node restriction and optional load format.
        num_gpus_available: The test is skipped when this is smaller than
            ``tp_size * pp_size``.
        method: Forwarded unchanged to ``compare_two_settings``.
        is_multimodal: When using dummy weights, nest the size overrides
            under ``text_config`` instead of applying them at top level.

    Raises:
        Exception: whatever ``compare_two_settings`` raises, unless the
            Ray Compiled Graph environment is active, in which case the
            failure is only logged.
    """
    (
        tp_size,
        pp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup

    multi_node_only, load_format = test_options

    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")

    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    # Work on a copy: the updates below must not be written back into the
    # dict owned by ``model_info``.
    hf_overrides = dict(model_info.hf_overrides or {})

    if load_format == "dummy":
        # Avoid OOM
        text_overrides = {
            "num_hidden_layers": 4,
            "hidden_size": 512,
            "intermediate_size": 800,
            "num_attention_heads": 4,
            "num_key_value_heads": 1,
        }

        if is_multimodal:
            hf_overrides.update({"text_config": text_overrides})
        else:
            hf_overrides.update(text_overrides)
    else:
        model_info.check_available_online(on_fail="skip")

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    if distributed_backend == "ray" and (vllm_major_version == "1"
                                         or specific_case):
        # For V1, test Ray Compiled Graph for all the tests
        # For V0, test Ray Compiled Graph for a subset of the tests
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        # NOTE: appended before pp_args/tp_args are built, so both runs get it.
        common_args.append("--disable-frontend-multiprocessing")
    else:
        pp_env = None

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_id, pp_args, tp_args, pp_env, method=method)
    except Exception:
        if pp_env is None:
            raise
        # Ray Compiled Graph tests are flaky,
        # so we don't want to fail the test
        logger.exception("Ray Compiled Graph tests failed")
373
374
375


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on TEST_MODELS before expanding a model's settings.
        params for model_id, settings in TEXT_GENERATION_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for text-generation models.

    ``num_gpus_available`` is presumably a pytest fixture defined outside
    this file -- confirm against the test suite's conftest.
    """
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=False)
402
403
404


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on TEST_MODELS before expanding a model's settings.
        params for model_id, settings in EMBEDDING_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for embedding models via ``encode``.

    ``num_gpus_available`` is presumably a pytest fixture defined outside
    this file -- confirm against the test suite's conftest.
    """
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="encode",
                is_multimodal=False)
431
432
433


@pytest.mark.parametrize(
    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
     "task", "test_options"),
    [
        # Filter on TEST_MODELS before expanding a model's settings.
        params for model_id, settings in MULTIMODAL_MODELS.items()
        if model_id in TEST_MODELS
        for params in settings.iter_params(model_id)
    ],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
    model_id: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for multimodal generation models.

    ``num_gpus_available`` is presumably a pytest fixture defined outside
    this file -- confirm against the test suite's conftest.
    """
    _compare_tp(model_id,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate",
                is_multimodal=True)