# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers on a node other than the head node, which can cause the test
 to fail.
"""
9
import os
10
from dataclasses import dataclass
11
from typing import List, Literal, NamedTuple, Optional
12

13
14
import pytest

15
from vllm.config import TaskOption
16
17
from vllm.logger import init_logger

18
from ..utils import compare_two_settings, fork_new_process_for_each_test
19

20
21
# Module-level logger; _compare_tp uses it to record tolerated Ray ADAG
# failures instead of failing the test.
logger = init_logger("test_pipeline_parallel")

22
23
# True only when the VLLM_MULTI_NODE environment variable is exactly "1"
# (it defaults to "0"). _compare_tp reads this flag to decide which
# configurations to skip.
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

24

25
26
27
28
29
30
31
class ParallelSetup(NamedTuple):
    """One parallelism configuration exercised by a pipeline-parallel test.

    Field order matters: ``_compare_tp`` unpacks instances positionally.
    """
    # Passed as the value of --tensor-parallel-size.
    tp_size: int
    # Passed as the value of --pipeline-parallel-size.
    pp_size: int
    # When true, --enforce-eager is added to the server arguments.
    eager_mode: bool
    # When true, --enable-chunked-prefill is added to the server arguments.
    chunked_prefill: bool


32
33
34
35
class PPTestOptions(NamedTuple):
    """Per-model options that are independent of the parallel layout.

    Field order matters: ``_compare_tp`` unpacks instances positionally.
    """
    # Skip the test unless VLLM_MULTI_NODE is enabled.
    multi_node_only: bool
    # When true, --trust-remote-code is added to the server arguments.
    trust_remote_code: bool
    # Value for --tokenizer-mode; the flag is omitted when None/empty.
    tokenizer_mode: Optional[str]
    # Value for --load-format; the flag is omitted when None/empty.
    load_format: Optional[str] = None
    # Value for --hf-overrides (this file passes a JSON object as a string);
    # the flag is omitted when None/empty.
    hf_overrides: Optional[str] = None
38
39


40
41
42
@dataclass
class PPTestSettings:
    """The full set of configurations to run for a single model.

    ``iter_params`` expands an instance into one parameter tuple per
    (parallel setup, backend/version pair) combination.
    """
    parallel_setups: List[ParallelSetup]
    # NOTE: the length of distributed_backends and
    # vllm_major_versions should be the same, and they
    # are first zipped together to iterate over all
    # test settings.
    distributed_backends: List[str]
    # vllm major version: "0" for V0, "1" for V1
    vllm_major_versions: List[str]
    task: TaskOption
    test_options: PPTestOptions

    def __post_init__(self):
        """Reject settings whose backend and version lists differ in length.

        Raises:
            ValueError: if ``distributed_backends`` and
                ``vllm_major_versions`` do not have the same length.
        """
        if len(self.distributed_backends) != len(self.vllm_major_versions):
            raise ValueError(
                f"Length mismatch: distributed_backends "
                f"({len(self.distributed_backends)}) != "
                f"vllm_major_versions ({len(self.vllm_major_versions)})")

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the thorough matrix: five parallel setups, three backends.

        The setups scale ``tp_base`` and ``pp_base`` by one or two and mix
        eager mode with chunked prefill. Each setup is paired with
        ("mp", V0), ("ray", V0) and ("ray", V1). The remaining keyword
        arguments are forwarded into ``PPTestOptions`` unchanged.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            # only ray is supported for V1
            distributed_backends=["mp", "ray", "ray"],
            vllm_major_versions=["0", "0", "1"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ) -> "PPTestSettings":
        """Build the minimal matrix: one eager setup on the "mp" backend, V0.

        The single setup uses ``tp_base`` x ``pp_base`` without chunked
        prefill. The remaining keyword arguments are forwarded into
        ``PPTestOptions`` unchanged.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            vllm_major_versions=["0"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        """Yield one pytest parameter tuple per configuration to run.

        Each tuple is ``(model_name, parallel_setup, backend,
        vllm_major_version, task, test_options)``. Backends and versions
        are paired by position, not combined as a cross product.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for backend, vllm_major_version in zip(self.distributed_backends,
                                                   self.vllm_major_versions):
                yield (model_name, parallel_setup, backend, vllm_major_version,
                       self.task, opts)
143
144


145
146
147
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# The values displayed here are only a rough indicator of the size of the model

148
# yapf: disable
149
150
# Text-generation models mapped to the PP test settings to run for each.
# Commented-out entries are kept as a record of why a model is not covered.
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True),  # noqa: E501
    "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-2b": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-12b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    # [Encoder-decoder]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

210
211
# Embedding models mapped to the PP test settings to run for each.
EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True),  # noqa: E501
}

217
218
# Multimodal models mapped to the PP test settings to run for each.
MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True),
    "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"),  # noqa: E501
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

242
# NOTE: You can update this on your local machine to run specific tests
243
# Only models listed here are parametrized into the tests below; every
# other entry in the *_MODELS tables is filtered out at collection time.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Meta-Llama-3-8B",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3-vision-128k-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


260
261
262
263
def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
):
    """Run one PP configuration against a TP-only baseline for a model.

    Builds two argument lists that share the common flags: one with
    pipeline parallelism on ``distributed_backend``, and one with tensor
    parallelism only on the "mp" backend. Both are handed to
    ``compare_two_settings`` (defined in ``..utils``; presumably it checks
    that the two configurations produce matching results -- confirm there).

    The test is skipped when:
      * fewer than ``tp_size * pp_size`` GPUs are available;
      * running multi-node with the "mp" backend;
      * the model is marked multi-node-only and VLLM_MULTI_NODE is off.

    Args:
        model_name: Model identifier passed through to the comparison.
        parallel_setup: TP/PP sizes plus eager and chunked-prefill toggles.
        distributed_backend: Executor backend for the PP run.
        vllm_major_version: "0" or "1"; exported as VLLM_USE_V1 on the
            Ray ADAG path only.
        task: Passed as --task unless it is "auto".
        test_options: Per-model flags (remote code, tokenizer mode, ...).
        num_gpus_available: Number of GPUs the test may use.
        method: Forwarded to ``compare_two_settings`` as ``method``.

    Raises:
        Exception: whatever ``compare_two_settings`` raises, but only when
            the run is not on the Ray ADAG path; ADAG failures are logged
            and swallowed.
    """
    (
        tp_size,
        pp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        multi_node_only,
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    if distributed_backend == "ray" and (vllm_major_version == "1"
                                         or specific_case):
        # For V1, test Ray ADAG for all the tests
        # For V0, test Ray ADAG for a subset of the tests
        pp_env = {
            "VLLM_USE_V1": vllm_major_version,
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of aDAG issue.
        # NOTE: appended before pp_args/tp_args are built, so both the PP
        # run and the TP baseline receive this flag.
        common_args.append("--disable-frontend-multiprocessing")
    else:
        pp_env = None

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             pp_args,
                             tp_args,
                             pp_env,
                             method=method)
    except Exception:
        # pp_env is only set on the Ray ADAG path, so it doubles as the
        # "is this an ADAG run" marker here.
        if pp_env is None:
            raise
        else:
            # Ray ADAG tests are flaky, so we don't want to fail the test
            logger.exception("Ray ADAG tests failed")
369
370
371


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend",
     "vllm_major_version", "task", "test_options"),
    [
        # Filter on the model first so settings for models outside
        # TEST_MODELS are never expanded.
        params for model_name, settings in TEXT_GENERATION_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for text-generation models ("generate")."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend",
     "vllm_major_version", "task", "test_options"),
    [
        # Filter on the model first so settings for models outside
        # TEST_MODELS are never expanded.
        params for model_name, settings in EMBEDDING_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for embedding models ("encode")."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="encode")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend",
     "vllm_major_version", "task", "test_options"),
    [
        # Filter on the model first so settings for models outside
        # TEST_MODELS are never expanded.
        params for model_name, settings in MULTIMODAL_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    vllm_major_version: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Run the PP-vs-TP comparison for multimodal models ("generate")."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                vllm_major_version,
                task,
                test_options,
                num_gpus_available,
                method="generate")