# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    --dp-size is the total number of data-parallel ranks across all nodes and
    must be divisible by --node-size; each node runs dp-size / node-size ranks.
    Every node passes the same --master-addr and --master-port (node 0's
    address in this example).

    Node 0 (assuming its IP address is 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
31

32
import os
33
from time import sleep
34
35

from vllm import LLM, SamplingParams
36
from vllm.platforms import current_platform
37
from vllm.utils.network_utils import get_open_port
38
39


40
41
def parse_args(argv=None):
    """Parse command-line options for the data-parallel inference example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is how the script calls it;
            passing an explicit list lets callers and tests parse options
            without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.
        ``enable_expert_parallel`` defaults to True and is switched off by
        ``--disable-expert-parallel``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="ibm-research/PowerMoE-3b",
        help="Model name or path",
    )
    parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
    parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
    parser.add_argument(
        "--node-size", type=int, default=1, help="Total number of nodes"
    )
    parser.add_argument(
        "--node-rank", type=int, default=0, help="Rank of the current node"
    )
    # The master address/port are only consulted when --node-size > 1; a
    # single-node run picks 127.0.0.1 and a free port on its own.
    parser.add_argument(
        "--master-addr",
        type=str,
        default="",
        help="Master node IP address (multi-node only)",
    )
    parser.add_argument(
        "--master-port",
        type=int,
        default=0,
        help="Master node port (multi-node only)",
    )
    parser.add_argument(
        "--enforce-eager", action="store_true", help="Enforce eager mode execution."
    )
    parser.add_argument(
        "--trust-remote-code", action="store_true", help="Trust remote code."
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=64,
        help="Maximum number of sequences to be processed in a single iteration.",
    )
    # Forwarded as LLM(max_model_len=...): this is the model context length,
    # not a per-iteration token budget (the old help text said the latter).
    parser.add_argument(
        "--max-model-len",
        type=int,
        help=(
            "Maximum model context length in tokens (prompt + output). "
            "Left to vLLM when unset."
        ),
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Seconds to wait for each DP process to finish before killing it.",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help="Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0].",
    )
    parser.add_argument(
        "--enable-dbo",
        action="store_true",
        help="Enable microbatched execution",
    )
    parser.add_argument(
        "--compilation-config",
        type=int,
        help="Compilation optimization (O) mode 0-3.",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        help="Quantization method passed through to vLLM (e.g. 'fp8').",
    )
    # Expert parallelism is on by default; the flag stores False to opt out.
    parser.add_argument(
        "--disable-expert-parallel",
        dest="enable_expert_parallel",
        action="store_false",
        help="Disable expert parallel (default: enabled).",
    )
    parser.set_defaults(enable_expert_parallel=True)

    return parser.parse_args(argv)


115
116
117
118
119
120
121
122
123
def main(
    model,
    dp_size,
    local_dp_rank,
    global_dp_rank,
    dp_master_ip,
    dp_master_port,
    GPUs_per_dp_rank,
    enforce_eager,
    enable_expert_parallel,
    trust_remote_code,
    max_num_seqs,
    max_model_len,
    compilation_config,
    gpu_memory_utilization,
    enable_dbo,
    quantization,
):
    """Run offline generation for one data-parallel (DP) rank.

    The rank's DP coordinates are exported through ``VLLM_DP_*`` environment
    variables. The rank then takes its contiguous share of a synthetic prompt
    list, builds an ``LLM`` from the given options, and prints at most five
    results.

    Args:
        model: Model name or path, forwarded to ``LLM``.
        dp_size: Total number of DP ranks; exported as ``VLLM_DP_SIZE`` and
            used to split the prompts.
        local_dp_rank: This rank's index on the current node; exported as
            ``VLLM_DP_RANK_LOCAL``.
        global_dp_rank: This rank's index across all nodes; exported as
            ``VLLM_DP_RANK``. It also selects the prompt shard and
            ``max_tokens``.
        dp_master_ip: Exported as ``VLLM_DP_MASTER_IP`` (must be a str).
        dp_master_port: Exported as ``VLLM_DP_MASTER_PORT``.
        GPUs_per_dp_rank: Forwarded as ``tensor_parallel_size``.
        enforce_eager, enable_expert_parallel, trust_remote_code,
        max_num_seqs, max_model_len, compilation_config,
        gpu_memory_utilization, enable_dbo, quantization:
            Forwarded unchanged as the same-named ``LLM`` keyword arguments.
    """
    os.environ.update(
        {
            "VLLM_DP_RANK": str(global_dp_rank),
            "VLLM_DP_RANK_LOCAL": str(local_dp_rank),
            "VLLM_DP_SIZE": str(dp_size),
            "VLLM_DP_MASTER_IP": dp_master_ip,
            "VLLM_DP_MASTER_PORT": str(dp_master_port),
        }
    )
    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes, so it is not touched here.

    seed_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    dataset = seed_prompts * 100

    # Each DP rank handles a different contiguous slice of the full dataset.
    # Slices differ in size by at most one: the first `extra` ranks take one
    # additional prompt.
    per_rank, extra = divmod(len(dataset), dp_size)
    lo = global_dp_rank * per_rank + min(global_dp_rank, extra)
    hi = lo + per_rank + (1 if global_dp_rank < extra else 0)
    # A rank with nothing to do still needs some input, so it gets a
    # placeholder prompt instead of an empty list.
    prompts = dataset[lo:hi] or ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Every DP rank may use its own sampling settings. As a demonstration,
    # even ranks generate 16 tokens and odd ranks generate 20.
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=20 if global_dp_rank % 2 else 16,
    )

    engine_options = dict(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=enforce_eager,
        enable_expert_parallel=enable_expert_parallel,
        trust_remote_code=trust_remote_code,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enable_dbo=enable_dbo,
        quantization=quantization,
        compilation_config=compilation_config,
    )
    llm = LLM(**engine_options)

    results = llm.generate(prompts, sampling_params)
    # Print only the first five results to keep per-rank logs short.
    for _, result in zip(range(5), results):
        print(
            f"DP rank {global_dp_rank}, Prompt: {result.prompt!r}, "
            f"Generated text: {result.outputs[0].text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)

205
206

if __name__ == "__main__":
    # Script entry point: spawn one process per DP rank hosted on this node,
    # wait for each with a timeout, and exit with a non-zero status if any
    # rank failed or had to be killed.
    import sys
    from multiprocessing import Process

    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    # Validate the topology with explicit checks rather than `assert`, which
    # `python -O` strips. A bad layout would otherwise fail later in an
    # obscure way: node_size=0 raises ZeroDivisionError, and an out-of-range
    # node_rank produces global ranks >= dp_size.
    if dp_size < 1 or node_size < 1:
        sys.exit("--dp-size and --node-size must both be >= 1")
    if dp_size % node_size != 0:
        sys.exit("dp_size should be divisible by node_size")
    if not 0 <= node_rank < node_size:
        sys.exit(f"--node-rank must be in [0, {node_size}), got {node_rank}")

    if node_size == 1:
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        # The defaults ("" and 0) are placeholders; every node must be told
        # where the master lives.
        if not args.master_addr or not 0 < args.master_port <= 65535:
            sys.exit(
                "multi-node runs require --master-addr and a valid --master-port"
            )
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    dp_per_node = dp_size // node_size

    if current_platform.is_rocm():
        from multiprocessing import set_start_method

        set_start_method("spawn", force=True)

    procs = []
    # This node hosts the contiguous block of global DP ranks
    # [node_rank * dp_per_node, (node_rank + 1) * dp_per_node).
    for local_dp_rank, global_dp_rank in enumerate(
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
    ):
        proc = Process(
            target=main,
            args=(
                args.model,
                dp_size,
                local_dp_rank,
                global_dp_rank,
                dp_master_ip,
                dp_master_port,
                tp_size,
                args.enforce_eager,
                args.enable_expert_parallel,
                args.trust_remote_code,
                args.max_num_seqs,
                args.max_model_len,
                args.compilation_config,
                args.gpu_memory_utilization,
                args.enable_dbo,
                args.quantization,
            ),
        )
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        # The timeout applies to each process in turn.
        proc.join(timeout=args.timeout)
        if proc.exitcode is None:
            # Report the configured timeout; it is not necessarily 5 minutes.
            print(
                f"Killing process {proc.pid} that didn't stop within "
                f"{args.timeout} seconds."
            )
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    sys.exit(exit_code)