# SPDX-License-Identifier: Apache-2.0
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assuming the node's IP address is 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
30
import os
31
from time import sleep
32
33
34
35
36

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def parse_args():
    """Parse the command-line options of the data-parallel example.

    Returns:
        An ``argparse.Namespace`` with attributes ``model``, ``dp_size``,
        ``tp_size``, ``node_size``, ``node_rank``, ``master_addr`` and
        ``master_port``.
    """
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Data Parallel Inference")
    # One (flag, type, default, help) entry per supported option.
    option_specs = (
        ("--model", str, "ibm-research/PowerMoE-3b", "Model name or path"),
        ("--dp-size", int, 2, "Data parallel size"),
        ("--tp-size", int, 2, "Tensor parallel size"),
        ("--node-size", int, 1, "Total number of nodes"),
        ("--node-rank", int, 0, "Rank of the current node"),
        ("--master-addr", str, "", "Master node IP address"),
        ("--master-port", int, 0, "Master node port"),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag,
                         type=value_type,
                         default=fallback,
                         help=help_text)
    return cli.parse_args()


71
72
73
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank):
    """Run offline generation for a single data-parallel (DP) rank.

    Exports the DP topology through ``VLLM_DP_*`` environment variables,
    takes this rank's disjoint share of a fixed prompt list, builds an
    ``LLM`` with ``GPUs_per_dp_rank``-way tensor parallelism and expert
    parallelism enabled, and prints up to five generated completions.

    Args:
        model: Model name or path passed to ``LLM``.
        dp_size: Total number of DP ranks across all nodes.
        local_dp_rank: Index of this rank on the current node.
        global_dp_rank: Index of this rank across all nodes; selects which
            slice of the prompts this rank processes.
        dp_master_ip: Address exported as ``VLLM_DP_MASTER_IP``.
        dp_master_port: Port exported as ``VLLM_DP_MASTER_PORT``.
        GPUs_per_dp_rank: Tensor-parallel size used by this rank's ``LLM``.
    """
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    # CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
    # engine processes.

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # With DP, each rank should process different prompts: usually all the
    # DP ranks share one full dataset and each rank handles a different part
    # of it. The first ``remainder`` ranks take one extra prompt each. This
    # way no prompt is dropped when len(prompts) is not a multiple of dp_size.
    floor, remainder = divmod(len(prompts), dp_size)

    def _slice_start(rank):
        # Offset of ``rank``'s slice; _slice_start(dp_size) == len(prompts).
        return rank * floor + min(rank, remainder)

    prompts = prompts[_slice_start(global_dp_rank):
                      _slice_start(global_dp_rank + 1)]
    if len(prompts) == 0:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=[16, 20][global_dp_rank % 2])

    # Create an LLM.
    llm = LLM(model=model,
              tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=True,
              enable_expert_parallel=True)
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # print only 5 outputs
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")

    # Give engines time to pause their processing loops before exiting.
    sleep(1)

130
131

if __name__ == "__main__":
    import sys
    from multiprocessing import Process

    args = parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    if node_size == 1:
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port
        # Multi-node runs need an explicit rendezvous point; the CLI
        # defaults ("" and 0) are only placeholders.
        if not dp_master_ip or not dp_master_port:
            raise ValueError("--master-addr and --master-port are required "
                             "when --node-size > 1")

    # Explicit checks rather than ``assert`` so validation survives
    # ``python -O``.
    if node_size < 1 or dp_size % node_size != 0:
        raise ValueError("dp_size should be divisible by node_size")
    if not 0 <= node_rank < node_size:
        raise ValueError("node_rank should be in [0, node_size)")
    dp_per_node = dp_size // node_size

    # Launch one process per DP rank hosted on this node.
    procs = []
    for local_dp_rank, global_dp_rank in enumerate(
            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
        proc = Process(target=main,
                       args=(args.model, dp_size, local_dp_rank,
                             global_dp_rank, dp_master_ip, dp_master_port,
                             tp_size))
        proc.start()
        procs.append(proc)

    # Wait for every rank; report the failure of any rank (or a hang) via the
    # script's exit status.
    exit_code = 0
    for proc in procs:
        proc.join(timeout=300)
        if proc.exitcode is None:
            print(f"Killing process {proc.pid} that "
                  f"didn't stop within 5 minutes.")
            proc.kill()
            exit_code = 1
        elif proc.exitcode:
            exit_code = proc.exitcode

    # sys.exit rather than the site-module ``exit`` helper, which is meant
    # for the interactive interpreter only.
    sys.exit(exit_code)