# SPDX-License-Identifier: Apache-2.0
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
 (2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
 important to set the distributed backend to "mp" to avoid Ray scheduling
 all workers in a node other than the head node, which can cause the test
 to fail.
"""
9
import os
10
from dataclasses import dataclass
11
from typing import List, Literal, NamedTuple, Optional
12

13
14
import pytest

15
from vllm.config import TaskOption
16
17
from vllm.logger import init_logger

18
from ..utils import compare_two_settings, fork_new_process_for_each_test
19

20
21
# Module logger. `_compare_tp` uses it to record, rather than raise, failures
# of the flaky Ray compiled-DAG comparison runs.
logger = init_logger("test_pipeline_parallel")

# True only when the environment variable VLLM_MULTI_NODE is exactly "1",
# i.e. the multi-node variant of this test (e.g. 2 nodes x 2 GPUs) is running.
# Unset or any other value selects single-node mode.
VLLM_MULTI_NODE = os.environ.get("VLLM_MULTI_NODE", "0") == "1"


class ParallelSetup(NamedTuple):
    """One parallelism configuration compared against a TP-only baseline.

    `_compare_tp` unpacks this tuple positionally, so field order matters.
    """

    # Tensor-parallel degree, passed as ``--tensor-parallel-size``.
    tp_size: int
    # Pipeline-parallel degree, passed as ``--pipeline-parallel-size``.
    pp_size: int
    # When True, ``--enforce-eager`` is added to the engine arguments.
    eager_mode: bool
    # When True, ``--enable-chunked-prefill`` is added to the engine arguments.
    chunked_prefill: bool


class PPTestOptions(NamedTuple):
    """Per-model engine options shared by every ``ParallelSetup`` of a model.

    `_compare_tp` unpacks this tuple positionally, so field order matters.
    """

    # Skip the model unless running in multi-node mode (VLLM_MULTI_NODE=1).
    multi_node_only: bool
    # Adds ``--trust-remote-code`` to the engine arguments when True.
    trust_remote_code: bool
    # Forwarded as ``--tokenizer-mode`` when set (e.g. "mistral").
    tokenizer_mode: Optional[str]
    # Forwarded as ``--load-format`` when set (e.g. "dummy").
    load_format: Optional[str] = None
    # JSON string forwarded as ``--hf-overrides`` when set.
    hf_overrides: Optional[str] = None


@dataclass
class PPTestSettings:
    """Matrix of parallel setups and executor backends to test for one model.

    Build instances with :meth:`detailed` (five setups on both the "mp" and
    "ray" backends) or :meth:`fast` (a single eager setup on "mp" only). Then
    expand them into pytest parameters with :meth:`iter_params`.
    """

    parallel_setups: List[ParallelSetup]
    distributed_backends: List[str]
    task: TaskOption
    test_options: PPTestOptions

    @staticmethod
    def detailed(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        multi_node_only: bool = False,
        task: TaskOption = "auto",
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        """Return settings covering TP/PP scaling, eager mode and chunked prefill.

        The setups use ``tp_base``/``pp_base`` and their doubles, so the
        largest one needs ``2 * tp_base * pp_base`` GPUs. Every setup runs on
        both the "mp" and "ray" distributed backends.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=False),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=tp_base,
                              pp_size=2 * pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=False,
                              chunked_prefill=True),
                ParallelSetup(tp_size=2 * tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp", "ray"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    @staticmethod
    def fast(
        *,
        tp_base: int = 1,
        pp_base: int = 2,
        task: TaskOption = "auto",
        multi_node_only: bool = False,
        trust_remote_code: bool = False,
        tokenizer_mode: Optional[str] = None,
        load_format: Optional[str] = None,
        hf_overrides: Optional[str] = None,
    ):
        """Return settings with one eager-mode setup on the "mp" backend only.

        The single setup needs ``tp_base * pp_base`` GPUs.
        """
        return PPTestSettings(
            parallel_setups=[
                ParallelSetup(tp_size=tp_base,
                              pp_size=pp_base,
                              eager_mode=True,
                              chunked_prefill=False),
            ],
            distributed_backends=["mp"],
            task=task,
            test_options=PPTestOptions(multi_node_only=multi_node_only,
                                       trust_remote_code=trust_remote_code,
                                       tokenizer_mode=tokenizer_mode,
                                       load_format=load_format,
                                       hf_overrides=hf_overrides),
        )

    def iter_params(self, model_name: str):
        """Yield one pytest parameter tuple per (setup, backend) pair.

        Each tuple is ``(model_name, parallel_setup, distributed_backend,
        task, test_options)``. This matches the argument names parametrized by
        the ``test_tp_*`` functions.
        """
        opts = self.test_options

        for parallel_setup in self.parallel_setups:
            for distributed_backend in self.distributed_backends:
                yield (model_name, parallel_setup, distributed_backend,
                       self.task, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
# memory. The values used here are only a rough indicator of the model size.

# yapf: disable
# Maps each model name to the PPTestSettings used when it is in TEST_MODELS.
TEXT_GENERATION_MODELS = {
    # [Decoder-only]
    # Uses Llama
    # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True),  # noqa: E501
    "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
    "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "bigscience/bloomz-1b1": PPTestSettings.fast(),
    "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
    "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True),  # noqa: E501
    "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-2b": PPTestSettings.fast(),
    "google/gemma-2-9b": PPTestSettings.fast(),
    "gpt2": PPTestSettings.fast(),
    "bigcode/starcoder": PPTestSettings.fast(),
    "EleutherAI/gpt-j-6b": PPTestSettings.fast(),
    "EleutherAI/pythia-12b": PPTestSettings.fast(),
    "ibm/PowerLM-3b": PPTestSettings.fast(),
    "ibm/PowerMoE-3b": PPTestSettings.fast(),
    # Uses Llama
    # "internlm/internlm-chat-7b": PPTestSettings.fast(),
    "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
    "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
    "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
    "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
    # Uses Llama
    # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
    "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
    "mosaicml/mpt-7b": PPTestSettings.fast(),
    "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
    "allenai/OLMo-1B-hf": PPTestSettings.fast(),
    "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(),
    "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(),
    "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
    "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "adept/persimmon-8b-chat": PPTestSettings.fast(),
    "microsoft/phi-2": PPTestSettings.fast(),
    "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True, load_format="dummy", hf_overrides='{"num_hidden_layers": 4, "hidden_size": 512, "intermediate_size": 800, "num_attention_heads": 4, "num_key_value_heads": 1}'),  # noqa: E501
    "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
    "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
    "bigcode/starcoder2-3b": PPTestSettings.fast(),
    "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
    # FIXME: Cannot load tokenizer in latest transformers version.
    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
    # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
    # [Encoder-only]
    # TODO: Implement PP
    # "facebook/bart-base": PPTestSettings.fast(),
}

EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True),  # noqa: E501
}

MULTIMODAL_MODELS = {
    # [Decoder-only]
    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
    "facebook/chameleon-7b": PPTestSettings.fast(),
    "adept/fuyu-8b": PPTestSettings.fast(),
    "THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True),
    "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True),
    "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(),
    "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(),
    "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(),
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
    "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True),
    "allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True),
    "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"),  # noqa: E501
    "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
    "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
    "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
    "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
    # [Encoder-decoder]
    # TODO: Implement PP
    # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
}
# yapf: enable

# NOTE: You can update this on your local machine to run specific tests.
# Only models listed here are parametrized by the test_tp_* functions.
TEST_MODELS = [
    # [LANGUAGE GENERATION]
    "microsoft/Phi-3.5-MoE-instruct",
    "meta-llama/Meta-Llama-3-8B",
    "ibm/PowerLM-3b",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
    # [MULTIMODAL GENERATION]
    "OpenGVLab/InternVL2-1B",
    "microsoft/Phi-3-vision-128k-instruct",
    "fixie-ai/ultravox-v0_5-llama-3_2-1b",
    # [LANGUAGE GENERATION - HYBRID ARCH]
    "ai21labs/Jamba-tiny-dev",
]


def _compare_tp(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available: int,
    *,
    method: Literal["generate", "encode"],
):
    """Run a model with pipeline parallelism and with TP only, and compare.

    Builds two engine argument lists and hands both to
    ``compare_two_settings``. The first uses the PP and TP sizes from
    ``parallel_setup`` on ``distributed_backend``. The second uses only TP on
    the "mp" backend.

    Args:
        model_name: Name of the model under test.
        parallel_setup: TP/PP sizes plus the eager-mode and chunked-prefill
            flags.
        distributed_backend: Executor backend for the PP run ("mp" or "ray").
        task: Forwarded as ``--task`` unless it is "auto".
        test_options: Per-model engine options (see ``PPTestOptions``).
        num_gpus_available: Number of GPUs the test may use; supplied by a
            pytest fixture in the calling tests.
        method: Comparison mode passed to ``compare_two_settings``.

    The test is skipped when there are too few GPUs, when the "mp" backend is
    requested in multi-node mode, or when a multi-node-only model runs on a
    single node. If the Ray compiled-DAG variant fails, the failure is logged
    instead of raised.
    """
    (
        tp_size,
        pp_size,
        eager_mode,
        chunked_prefill,
    ) = parallel_setup
    (
        multi_node_only,
        trust_remote_code,
        tokenizer_mode,
        load_format,
        hf_overrides,
    ) = test_options

    if num_gpus_available < tp_size * pp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
    if VLLM_MULTI_NODE and distributed_backend == "mp":
        pytest.skip("Skipping multi-node pipeline parallel test for "
                    "multiprocessing distributed backend")
    if multi_node_only and not VLLM_MULTI_NODE:
        pytest.skip("Not in multi-node setting")

    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if chunked_prefill:
        common_args.append("--enable-chunked-prefill")
    if eager_mode:
        common_args.append("--enforce-eager")
    if task != "auto":
        common_args.extend(["--task", task])
    if trust_remote_code:
        common_args.append("--trust-remote-code")
    if tokenizer_mode:
        common_args.extend(["--tokenizer-mode", tokenizer_mode])
    if load_format:
        common_args.extend(["--load-format", load_format])
    if hf_overrides:
        common_args.extend(["--hf-overrides", hf_overrides])

    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
            and chunked_prefill):
        # Test Ray ADAG for a subset of the tests
        pp_env = {
            "VLLM_USE_RAY_COMPILED_DAG": "1",
            "VLLM_USE_RAY_SPMD_WORKER": "1",
            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
        }
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of aDAG issue.
        common_args.append("--disable-frontend-multiprocessing")
    else:
        pp_env = None

    pp_args = [
        *common_args,
        "--pipeline-parallel-size",
        str(pp_size),
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
    ]

    # Baseline without pipeline parallelism.
    # NOTE: use the mp backend for TP. PP tests might involve multiple nodes,
    # and Ray might schedule all workers on a node other than the head node,
    # which can cause the test to fail.
    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    try:
        compare_two_settings(model_name,
                             pp_args,
                             tp_args,
                             pp_env,
                             method=method)
    except Exception:
        if pp_env is None:
            raise
        # Ray ADAG tests are flaky, so we don't want to fail the test
        logger.exception("Ray ADAG tests failed")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in TEXT_GENERATION_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_generation(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only generation outputs for text-generation models."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in EMBEDDING_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_language_embedding(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only ``encode`` outputs for embedding models."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="encode")


@pytest.mark.parametrize(
    ("model_name", "parallel_setup", "distributed_backend", "task",
     "test_options"),
    [
        params for model_name, settings in MULTIMODAL_MODELS.items()
        if model_name in TEST_MODELS
        for params in settings.iter_params(model_name)
    ],
)
@fork_new_process_for_each_test
def test_tp_multimodal_generation(
    model_name: str,
    parallel_setup: ParallelSetup,
    distributed_backend: str,
    task: TaskOption,
    test_options: PPTestOptions,
    num_gpus_available,
):
    """Compare PP and TP-only generation outputs for multimodal models."""
    _compare_tp(model_name,
                parallel_setup,
                distributed_backend,
                task,
                test_options,
                num_gpus_available,
                method="generate")