# SPDX-License-Identifier: Apache-2.0
"""
Usage:
Single node:
    python examples/offline_inference/data_parallel.py \
            --model="ibm-research/PowerMoE-3b" \
            --dp-size=2 \
            --tp-size=2

Multi-node:
    Node 0 (assume the node has ip of 10.99.48.128):
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=0 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
    Node 1:
            python examples/offline_inference/data_parallel.py \
                    --model="ibm-research/PowerMoE-3b" \
                    --dp-size=2 \
                    --tp-size=2 \
                    --node-size=2 \
                    --node-rank=1 \
                    --master-addr=10.99.48.128 \
                    --master-port=13345
"""
import os

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def _shard_bounds(total, num_shards, shard_index):
    """Return the ``(start, end)`` slice bounds of one near-even shard.

    ``total`` items are split into ``num_shards`` contiguous shards whose
    sizes differ by at most one; the first ``total % num_shards`` shards
    receive the extra item, so no item is left unassigned.
    """
    floor, remainder = divmod(total, num_shards)

    def start_of(index):
        return index * floor + min(index, remainder)

    return start_of(shard_index), start_of(shard_index + 1)


def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         dp_master_port, GPUs_per_dp_rank):
    """Run offline generation for a single data-parallel (DP) rank.

    Exports the ``VLLM_DP_*`` environment variables and
    ``CUDA_VISIBLE_DEVICES`` for this rank, selects this rank's share of
    the sample prompts, builds an ``LLM`` and prints up to five
    generations.

    Args:
        model: Model name or path passed to ``LLM``.
        dp_size: Total number of DP ranks across all nodes.
        local_dp_rank: Index of this rank on the current node; used to
            pick the GPU indices made visible to the process.
        global_dp_rank: Index of this rank across all nodes; used to pick
            the prompt shard and exported as ``VLLM_DP_RANK``.
        dp_master_ip: Address exported as ``VLLM_DP_MASTER_IP``.
        dp_master_port: Port exported as ``VLLM_DP_MASTER_PORT``.
        GPUs_per_dp_rank: Number of GPUs given to this rank; also used as
            the tensor parallel size.
    """
    os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
    # set devices for each dp_rank: a contiguous run of GPU indices
    # selected by the node-local rank
    first_gpu = local_dp_rank * GPUs_per_dp_rank
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(first_gpu, first_gpu + GPUs_per_dp_rank))

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # with DP, each rank should process different prompts.
    # usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    # The split is near-even: plain floor division would silently drop the
    # trailing `len(prompts) % dp_size` prompts.
    start, end = _shard_bounds(len(prompts), dp_size, global_dp_rank)
    prompts = prompts[start:end]
    if not prompts:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=[16, 20][global_dp_rank % 2])

    # Create an LLM.
    llm = LLM(model=model,
              tensor_parallel_size=GPUs_per_dp_rank,
              enforce_eager=True,
              enable_expert_parallel=True)
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for i, output in enumerate(outputs):
        if i >= 5:
            # print only 5 outputs
            break
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
              f"Generated text: {generated_text!r}")


if __name__ == "__main__":
    # Script entry point: parse the CLI, work out which global DP ranks
    # belong to this node, run one `main` process per rank, and exit with
    # a non-zero status if any of them failed.
    import argparse
    import sys
    from multiprocessing import Process

    parser = argparse.ArgumentParser(description="Data Parallel Inference")
    parser.add_argument("--model",
                        type=str,
                        default="ibm-research/PowerMoE-3b",
                        help="Model name or path")
    parser.add_argument("--dp-size",
                        type=int,
                        default=2,
                        help="Data parallel size")
    parser.add_argument("--tp-size",
                        type=int,
                        default=2,
                        help="Tensor parallel size")
    parser.add_argument("--node-size",
                        type=int,
                        default=1,
                        help="Total number of nodes")
    parser.add_argument("--node-rank",
                        type=int,
                        default=0,
                        help="Rank of the current node")
    parser.add_argument("--master-addr",
                        type=str,
                        default="",
                        help="Master node IP address")
    parser.add_argument("--master-port",
                        type=int,
                        default=0,
                        help="Master node port")
    args = parser.parse_args()

    dp_size = args.dp_size
    tp_size = args.tp_size
    node_size = args.node_size
    node_rank = args.node_rank

    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and parser.error prints the usage text.
    if node_size < 1:
        parser.error("node_size should be at least 1")
    if dp_size % node_size != 0:
        parser.error("dp_size should be divisible by node_size")
    if not 0 <= node_rank < node_size:
        parser.error("node_rank should be in the range [0, node_size)")

    if node_size == 1:
        # Single node: rendezvous over loopback on any free port.
        dp_master_ip = "127.0.0.1"
        dp_master_port = get_open_port()
    else:
        # Multi-node: every node must be given the same master endpoint,
        # so the placeholder defaults ("" / 0) are rejected up front.
        if not args.master_addr or args.master_port <= 0:
            parser.error("--master-addr and --master-port are required "
                         "when --node-size is greater than 1")
        dp_master_ip = args.master_addr
        dp_master_port = args.master_port

    dp_per_node = dp_size // node_size

    procs = []
    # Local ranks count from 0 on every node; global ranks are offset by
    # the number of ranks hosted on the preceding nodes.
    for local_dp_rank, global_dp_rank in enumerate(
            range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
        proc = Process(target=main,
                       args=(args.model, dp_size, local_dp_rank,
                             global_dp_rank, dp_master_ip, dp_master_port,
                             tp_size))
        proc.start()
        procs.append(proc)

    exit_code = 0
    for proc in procs:
        proc.join()
        if proc.exitcode:
            # Report the exit code of the last failing rank.
            exit_code = proc.exitcode

    # sys.exit, not the `exit` helper that `site` installs for
    # interactive sessions.
    sys.exit(exit_code)