# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Experimental support for data-parallel inference with torchrun.

Note that data load balancing and distribution are done outside of the vLLM
engine; no internal load balancing is supported in external_launcher mode.

To run this example:
```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```

With custom parallelism settings:
```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
    --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
"""

20
21
import argparse

22
23
from vllm import LLM, SamplingParams

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

def parse_args():
    """Parse this example's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``tp_size``, ``pp_size``,
        ``dp_size``, ``enable_ep``, ``model``, ``max_model_len``,
        ``gpu_memory_utilization`` and ``seed``.
    """
    cli = argparse.ArgumentParser(description="Data-parallel inference with torchrun")

    # One (flag, add_argument keyword arguments) entry per option, kept in
    # the order the options should be listed by ``--help``.
    option_table = (
        (
            "--tp-size",
            {"type": int, "default": 1, "help": "Tensor parallel size (default: 1)"},
        ),
        (
            "--pp-size",
            {"type": int, "default": 1, "help": "Pipeline parallel size (default: 1)"},
        ),
        (
            "--dp-size",
            {"type": int, "default": 2, "help": "Data parallel size (default: 2)"},
        ),
        (
            "--enable-ep",
            {
                "action": "store_true",
                "help": "Enable expert parallel (default: False)",
            },
        ),
        (
            "--model",
            {
                "type": str,
                "default": "microsoft/Phi-mini-MoE-instruct",
                "help": "Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
            },
        ),
        (
            "--max-model-len",
            {
                "type": int,
                "default": 4096,
                "help": "Maximum model length (default: 4096)",
            },
        ),
        (
            "--gpu-memory-utilization",
            {
                "type": float,
                "default": 0.6,
                "help": "GPU memory utilization (default: 0.6)",
            },
        ),
        (
            "--seed",
            {"type": int, "default": 1, "help": "Random seed (default: 1)"},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


# Parse the command-line options once, at module level; the engine
# construction further down reads its settings from `args`.
args = parse_args()


82
83
84
85
86
87
# Create prompts, the same across all ranks.
# Every entry must be a string: each prompt is later prefixed with its index
# and handed to the engine.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
89
90
91
92
93
94
95
96
97
98

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this LLM engine/instance only creates one worker.
# It is important to set an explicit seed so that all ranks share the
# same random seed and sampling is deterministic across ranks.
llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tp_size,
    data_parallel_size=args.dp_size,
    pipeline_parallel_size=args.pp_size,
    enable_expert_parallel=args.enable_ep,
    distributed_executor_backend="external_launcher",
    max_model_len=args.max_model_len,
    gpu_memory_utilization=args.gpu_memory_utilization,
    seed=args.seed,
)

# Read this process's data-parallel rank and the data-parallel size from the
# engine's parallel configuration.
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size

# Shard the prompts across data-parallel ranks: a rank keeps only the prompts
# whose index is congruent to its own rank modulo dp_size. Each kept prompt is
# prefixed with its original index so the outputs can be told apart.
prompts = [
    f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]

outputs = llm.generate(prompts, sampling_params)

# Print every prompt of this rank's shard together with its first completion.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
    )

126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Further tips:

1. To communicate control messages across all ranks, use the CPU group,
a PyTorch ProcessGroup with the GLOO backend.

```python
import torch.distributed as dist
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""