# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
31

32
import os
33
from time import sleep
34
35
36
37
38

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def parse_args(argv=None):
    """Parse command-line options for the data-parallel inference example.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with one attribute per option below
        (dashes in option names become underscores, e.g. ``dp_size``).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument(
        "--model",
        type=str,
        default="ibm-research/PowerMoE-3b",
        help="Model name or path",
    )
    parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
    parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
    parser.add_argument(
        "--node-size", type=int, default=1, help="Total number of nodes"
    )
    parser.add_argument(
        "--node-rank", type=int, default=0, help="Rank of the current node"
    )
    # Only consulted for multi-node runs; single-node runs pick a local
    # address and an open port automatically.
    parser.add_argument(
        "--master-addr", type=str, default="", help="Master node IP address"
    )
    parser.add_argument("--master-port", type=int, default=0, help="Master node port")
    parser.add_argument(
        "--enforce-eager", action="store_true", help="Enforce eager mode execution."
    )
    parser.add_argument(
        "--trust-remote-code", action="store_true", help="Trust remote code."
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=64,
        help="Maximum number of sequences to be processed in a single iteration.",
    )
    # Forwarded as LLM(max_model_len=...): the model context length, not a
    # per-iteration token budget as the previous help text claimed.
    parser.add_argument(
        "--max-model-len",
        type=int,
        help="Maximum model context length (prompt plus generated tokens).",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Number of seconds before unresponsive process is killed.",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.8,
        help="Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0].",
    )
    parser.add_argument(
        "--enable-dbo",
        action="store_true",
        help="Enable microbatched execution",
    )
    parser.add_argument(
        "--compilation-config",
        type=int,
        help="Compilation optimization (O) level 0-3.",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        help="Quantization method forwarded to the vLLM LLM constructor.",
    )
    return parser.parse_args(argv)


def main(
    model,
    dp_size,
    local_dp_rank,
    global_dp_rank,
    dp_master_ip,
    dp_master_port,
    GPUs_per_dp_rank,
    enforce_eager,
    trust_remote_code,
    max_num_seqs,
    max_model_len,
    compilation_config,
    gpu_memory_utilization,
    enable_dbo,
    quantization,
):
    """Run offline generation for a single data-parallel (DP) rank.

    The function does the following, in order:

    1. Exports this rank's DP topology through ``VLLM_DP_*`` environment
       variables.
    2. Takes a contiguous, near-even share of a fixed prompt list.
    3. Builds an ``LLM`` with expert parallelism enabled.
    4. Generates, then prints at most the first five outputs.

    Args:
        model: Model name or path passed to ``LLM``.
        dp_size: Total number of DP ranks across all nodes.
        local_dp_rank: Index of this rank on the current node.
        global_dp_rank: Index of this rank across all nodes. It selects the
            prompt shard and the ``max_tokens`` value.
        dp_master_ip: DP master address, exported as ``VLLM_DP_MASTER_IP``.
        dp_master_port: DP master port, exported as ``VLLM_DP_MASTER_PORT``.
        GPUs_per_dp_rank: Tensor-parallel size of this rank's engine.
        enforce_eager, trust_remote_code, max_num_seqs, max_model_len,
        compilation_config, gpu_memory_utilization, enable_dbo,
        quantization: Forwarded unchanged to the ``LLM`` constructor.
    """
    # Publish this rank's place in the DP group before the engine is created
    # (presumably consumed by vLLM's DP setup).
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # With DP, each rank should process different prompts. Usually all the
    # DP ranks process a full dataset, and each rank processes a different
    # part of it.
    floor = len(prompts) // dp_size
    remainder = len(prompts) % dp_size

    # Distribute prompts into even groups: the first `remainder` ranks get
    # one extra prompt each, so shard sizes differ by at most one.
    def start(rank):
        return rank * floor + min(rank, remainder)

    prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]

    if not prompts:
        # If any rank has no prompts to process, we still need a placeholder
        # prompt so this rank runs generation like the others.
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object. Since we are doing data parallel,
    # every rank can have different sampling params; here even and odd ranks
    # get different max_tokens for demonstration.
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
    )

    # Create an LLM.
    llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=enforce_eager,
        enable_expert_parallel=True,
        trust_remote_code=trust_remote_code,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enable_dbo=enable_dbo,
        quantization=quantization,
        compilation_config=compilation_config,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs (only the first 5).
    for i, output in enumerate(outputs):
        if i >= 5:
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)

if __name__ == "__main__":
    import sys
    from multiprocessing import Process

    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    # Validate the topology with explicit errors rather than `assert`,
    # which is stripped when Python runs with -O. node_size < 1 would
    # otherwise fail with ZeroDivisionError below.
    if node_size < 1 or dp_size % node_size != 0:
        sys.exit("dp_size should be divisible by node_size")
    if not 0 <= node_rank < node_size:
        sys.exit(f"node_rank must be in [0, {node_size}), got {node_rank}")

    if node_size == 1:
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        # The defaults ("" and 0) cannot identify a shared master across
        # nodes, so require them explicitly instead of exporting bad values.
        if not args.master_addr or not args.master_port:
            sys.exit(
                "--master-addr and --master-port are required when --node-size > 1"
            )
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    dp_per_node = dp_size // node_size

    # One process per DP rank owned by this node; global ranks are the
    # contiguous range [node_rank * dp_per_node, (node_rank + 1) * dp_per_node).
    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
    ):
        proc = Process(
            target=main,
            args=(
                args.model,
                dp_size,
                local_dp_rank,
                global_dp_rank,
                dp_master_ip,
                dp_master_port,
                tp_size,
                args.enforce_eager,
                args.trust_remote_code,
                args.max_num_seqs,
                args.max_model_len,
                args.compilation_config,
                args.gpu_memory_utilization,
                args.enable_dbo,
                args.quantization,
            ),
        )
        proc.start()
        procs.append(proc)

    # Note: the timeout is applied to each join in turn, so processes joined
    # later may effectively get more wall-clock time than earlier ones.
    exit_code = 0
    for proc in procs:
        proc.join(timeout=args.timeout)
        if proc.exitcode is None:
            print(
                f"Killing process {proc.pid} that didn't stop within "
                f"{args.timeout} seconds."
            )
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    sys.exit(exit_code)