# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
30
import os
31
from time import sleep
32
33
34
35
36

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def parse_args(argv=None):
    """Parse the command-line options of the data parallel example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``model``, ``dp_size``,
        ``tp_size``, ``node_size``, ``node_rank``, ``master_addr``,
        ``master_port``, ``enforce_eager`` and ``trust_remote_code``.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument("--model",
                        type=str,
                        default="ibm-research/PowerMoE-3b",
                        help="Model name or path")
    parser.add_argument("--dp-size",
                        type=int,
                        default=2,
                        help="Data parallel size")
    parser.add_argument("--tp-size",
                        type=int,
                        default=2,
                        help="Tensor parallel size")
    parser.add_argument("--node-size",
                        type=int,
                        default=1,
                        help="Total number of nodes")
    parser.add_argument("--node-rank",
                        type=int,
                        default=0,
                        help="Rank of the current node")
    parser.add_argument("--master-addr",
                        type=str,
                        default="",
                        help="Master node IP address")
    parser.add_argument("--master-port",
                        type=int,
                        default=0,
                        help="Master node port")
    parser.add_argument("--enforce-eager",
                        action='store_true',
                        help="Enforce eager mode execution.")
    parser.add_argument("--trust-remote-code",
                        action='store_true',
                        help="Trust remote code.")
    return parser.parse_args(argv)


77
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code):
    """Run offline generation for a single data parallel (DP) rank.

    Exports the ``VLLM_DP_*`` environment variables for this rank, picks the
    slice of the sample prompts that belongs to it, builds an ``LLM`` and
    prints up to five of the generated completions.

    Args:
        model: Model name or path handed to ``LLM``.
        dp_size: Total number of DP ranks.
        local_dp_rank: Rank of this process within the current node.
        global_dp_rank: Rank of this process across all nodes.
        dp_master_ip: Address exported as ``VLLM_DP_MASTER_IP``.
        dp_master_port: Port exported as ``VLLM_DP_MASTER_PORT``.
        GPUs_per_dp_rank: Passed to ``LLM`` as ``tensor_parallel_size``.
        enforce_eager: Passed to ``LLM`` as ``enforce_eager``.
        trust_remote_code: Passed to ``LLM`` as ``trust_remote_code``.
    """
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # with DP, each rank should process different prompts.
    # usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    # The first `remainder` ranks take one extra prompt so that nothing is
    # dropped when the prompt count is not divisible by dp_size.
    floor, remainder = divmod(len(prompts), dp_size)

    def shard_start(rank):
        return rank * floor + min(rank, remainder)

    prompts = prompts[shard_start(global_dp_rank):shard_start(global_dp_rank +
                                                              1)]
    if not prompts:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=[16, 20][global_dp_rank % 2])

    # Create an LLM.
    llm = LLM(
        model=model,
        tensor_parallel_size=GPUs_per_dp_rank,
        enforce_eager=enforce_eager,
        enable_expert_parallel=True,
        trust_remote_code=trust_remote_code,
    )
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # print only 5 outputs
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")

    # Give engines time to pause their processing loops before exiting.
    sleep(1)

139
140

if __name__ == "__main__":
    # Script entry point: spawn one process per DP rank hosted on this node,
    # wait for all of them and propagate a non-zero exit code on failure.
    from multiprocessing import Process

    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    if node_size == 1:
        # Single node: rendezvous over loopback on a free port.
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        # Multi node: the rendezvous address and port come from the CLI.
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dp_size % node_size != 0:
        raise ValueError("dp_size should be divisible by node_size")
    dp_per_node = dp_size // node_size

    # This node hosts the contiguous global ranks
    # [node_rank * dp_per_node, (node_rank + 1) * dp_per_node).
    first_global_rank = node_rank * dp_per_node
    procs = []
    for local_dp_rank in range(dp_per_node):
        global_dp_rank = first_global_rank + local_dp_rank
        proc = Process(target=main,
                       args=(args.model, dp_size, local_dp_rank,
                             global_dp_rank, dp_master_ip, dp_master_port,
                             tp_size, args.enforce_eager,
                             args.trust_remote_code))
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)
        if proc.exitcode is None:
            # Still running after the timeout: kill it and report failure.
            print(f"Killing process {proc.pid} that "
                  f"didn't stop within 5 minutes.")
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # `exit()` is injected by the `site` module and may be absent; raising
    # SystemExit is the equivalent that always works.
    raise SystemExit(exit_code)