# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
8
9
            -dp=2 \
            -tp=2
10
11
12
13
14

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
15
16
17
                    -dp=2 \
                    -tp=2 \
                    --nnodes=2 \
18
19
20
21
22
23
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
24
25
26
                    -dp=2 \
                    -tp=2 \
                    --nnodes=2 \
27
28
29
30
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
31

32
import os
33
from time import sleep
34

35
from vllm import LLM, EngineArgs, SamplingParams
36
from vllm.platforms import current_platform
37
from vllm.utils.argparse_utils import FlexibleArgumentParser
38
from vllm.utils.network_utils import get_open_port
39
40


41
42
def create_parser():
    """Build the command-line parser for the data-parallel example.

    The parser carries every option registered by ``EngineArgs.add_cli_args``,
    overrides the defaults of ``model`` and ``enable_expert_parallel`` for this
    example, and adds a ``--timeout`` option that is not an engine argument.

    Returns:
        The configured ``FlexibleArgumentParser``.
    """
    cli = FlexibleArgumentParser(description="Data Parallel Inference")

    # Register all engine options first, then override the defaults this
    # example wants.
    EngineArgs.add_cli_args(cli)
    example_defaults = {
        "model": "ibm-research/PowerMoE-3b",
        "enable_expert_parallel": True,
    }
    cli.set_defaults(**example_defaults)

    # Launcher-only option (not part of EngineArgs); the script pops it from
    # the parsed args before forwarding the rest to the engine.
    cli.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Number of seconds before unresponsive process is killed.",
    )
    return cli


def main(
    dp_size,
    local_dp_rank,
    global_dp_rank,
    dp_master_ip,
    dp_master_port,
    engine_args,
):
    """Run offline generation for a single data-parallel rank.

    Publishes this rank's data-parallel topology through ``VLLM_DP_*``
    environment variables and takes a contiguous slice of a shared prompt
    list. It then builds an ``LLM`` from ``engine_args``, generates
    completions, and prints at most five of them.

    Args:
        dp_size: Total number of data-parallel ranks across all nodes.
        local_dp_rank: Index of this rank on the current node.
        global_dp_rank: Index of this rank across all nodes; selects the
            prompt shard and the ``max_tokens`` value.
        dp_master_ip: Coordinator address exported as ``VLLM_DP_MASTER_IP``
            (must be a string).
        dp_master_port: Coordinator port exported as ``VLLM_DP_MASTER_PORT``.
        engine_args: Keyword arguments forwarded to ``LLM``.
    """
    # Insertion order matches the order in which the variables are exported.
    dp_env = {
        "VLLM_DP_RANK": str(global_dp_rank),
        "VLLM_DP_RANK_LOCAL": str(local_dp_rank),
        "VLLM_DP_SIZE": str(dp_size),
        "VLLM_DP_MASTER_IP": dp_master_ip,
        "VLLM_DP_MASTER_PORT": str(dp_master_port),
    }
    os.environ.update(dp_env)

    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Every rank builds the full dataset and keeps only its own part.
    dataset = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # Contiguous shards whose sizes differ by at most one: the first
    # `extra` ranks each receive one additional prompt.
    base, extra = divmod(len(dataset), dp_size)
    begin = global_dp_rank * base + min(global_dp_rank, extra)
    end = begin + base + (1 if global_dp_rank < extra else 0)
    prompts = dataset[begin:end]

    if not prompts:
        # If any rank has no prompts to process, we need to set a
        # placeholder prompt.
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Each DP rank may use its own sampling params; for demonstration,
    # even ranks generate up to 16 tokens and odd ranks up to 20.
    max_tokens = 20 if global_dp_rank % 2 else 16
    sampling_params = SamplingParams(
        temperature=0.8, top_p=0.95, max_tokens=max_tokens
    )

    llm = LLM(**engine_args)
    outputs = llm.generate(prompts, sampling_params)

    # Show only the first five results.
    for output in outputs[:5]:
        print(
            f"DP rank {global_dp_rank}, Prompt: {output.prompt!r}, "
            f"Generated text: {output.outputs[0].text!r}"
        )

    # Give engines time to pause their processing loops before exiting.
    sleep(1)


if __name__ == "__main__":
132
133
134
135
136
137
138
139
140
141
    parser = create_parser()
    args = vars(parser.parse_args())

    # Extract DP-specific args
    dp_size = args.pop("data_parallel_size")
    nnodes = args.get("nnodes", 1)
    node_rank = args.get("node_rank", 0)
    master_addr = args.get("master_addr", "")
    master_port = args.get("master_port", 0)
    timeout = args.pop("timeout")
142

143
144
    # Remaining args are engine args
    engine_args = args
145

146
    if nnodes == 1:
147
148
149
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
150
151
        dp_master_ip = master_addr
        dp_master_port = master_port
152

153
154
    assert dp_size % nnodes == 0, "dp_size should be divisible by nnodes"
    dp_per_node = dp_size // nnodes
155

156
    from multiprocessing import Process
157

158
159
160
161
162
    if current_platform.is_rocm():
        from multiprocessing import set_start_method

        set_start_method("spawn", force=True)

163
    procs = []
164
    for local_dp_rank, global_dp_rank in enumerate(
165
166
167
168
169
170
171
172
173
174
        range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
    ):
        proc = Process(
            target=main,
            args=(
                dp_size,
                local_dp_rank,
                global_dp_rank,
                dp_master_ip,
                dp_master_port,
175
                engine_args,
176
177
            ),
        )
178
179
        proc.start()
        procs.append(proc)
180
    exit_code = 0
181
    for proc in procs:
182
        proc.join(timeout=timeout)
183
        if proc.exitcode is None:
184
            print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
185
186
187
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
188
189
190
            exit_code = proc.exitcode

    exit(exit_code)