# test_pynccl.py
import functools
import multiprocessing
import os

import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             get_world_group, graph_capture,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require all to succeed.

    One ``multiprocessing.Process`` is started per rank with
    ``target=fn, args=(env,)``, where ``env`` is a dict holding that rank's
    ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``,
    ``MASTER_ADDR`` and ``MASTER_PORT`` values (all strings). Every child is
    joined before any exit code is inspected.

    Args:
        fn: Callable invoked in each child with the per-rank ``env`` dict as
            its only argument.
        world_size: Number of processes to launch; ranks are
            ``0 .. world_size - 1``.

    Raises:
        AssertionError: If any child finishes with a non-zero exit code.
    """
    processes = []
    for rank in range(world_size):
        env = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        proc = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(proc)
        proc.start()

    # Join everything first so no child is left running when we fail.
    for proc in processes:
        proc.join()

    for rank, proc in enumerate(processes):
        assert proc.exitcode == 0, (
            f"worker rank {rank} exited with code {proc.exitcode}")

def worker_fn_wrapper(fn):
    """Adapt a zero-argument worker into the ``fn(env)`` form used by
    ``distributed_run``.

    The returned callable applies ``env`` via
    ``update_environment_variables``, selects the CUDA device named by
    ``LOCAL_RANK``, calls ``init_distributed_environment()`` and finally
    calls ``fn()``.

    ``functools.wraps`` keeps the worker's ``__name__``/``__qualname__`` on
    the wrapper, so a decorated module-level worker can still be pickled by
    reference (needed by non-fork multiprocessing start methods) and shows
    up under its own name in tracebacks.

    Args:
        fn: Worker callable taking no arguments.

    Returns:
        A callable taking a single ``env`` dict of environment variables.
    """

    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the world group and check the result.

    Every rank contributes ones, so after one all-reduce the tensor's mean
    must equal the communicator's ``world_size``.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    # Allocate directly on the rank's GPU instead of staging on the host.
    tensor = torch.ones(16,
                        1024,
                        1024,
                        dtype=torch.float32,
                        device=f"cuda:{pynccl_comm.rank}")
    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(tensor)
    # The communicator exposes its own `stream` (the CUDA-graph test in this
    # file synchronizes it explicitly), so the collective may not be ordered
    # with the default stream; wait for the device before reading the result.
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    assert result == pynccl_comm.world_size, (
        f"expected mean {pynccl_comm.world_size}, got {result}")


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Run the basic all-reduce worker on two ranks."""
    distributed_run(fn=worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """All-reduce independently inside two disjoint 2-rank groups.

    Ranks 0 and 1 all-reduce a tensor of ones twice (expected mean 4);
    ranks 2 and 3 all-reduce once (expected mean 2).
    """
    rank = torch.distributed.get_rank()
    in_first_group = rank in (0, 1)
    device = torch.device(f"cuda:{rank}")
    # Both groups are created on every rank: `new_group` has to be entered by
    # all processes, including those that are not members of the group.
    groups = [
        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
        torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
    ]
    group = groups[0] if in_first_group else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # Each all-reduce over a 2-rank group doubles the values.
    num_rounds = 2 if in_first_group else 1
    with pynccl_comm.change_state(enable=True):
        # two groups can communicate independently
        for _ in range(num_rounds):
            pynccl_comm.all_reduce(tensor)
        # Wait for the device before reading: the communicator has its own
        # `stream`, which may differ from the default stream.
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == 2**num_rounds, (
            f"rank {rank}: expected mean {2**num_rounds}, got {result}")


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce():
    """Test pynccl with multiple groups by calling it directly."""
    # this tests pynccl for multiple tp groups, in a standalone way
    # i.e. call `pynccl_comm.all_reduce` directly
    distributed_run(multiple_allreduce_worker_fn, 4)


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """All-reduce through vLLM's ``tensor_model_parallel_all_reduce``.

    Ranks 0 and 1 reduce a tensor of ones twice (expected mean 4); the other
    ranks reduce once (expected mean 2).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # NOTE(review): presumably (tensor-parallel size, pipeline-parallel size),
    # giving tp groups [0, 1] and [2, 3] -- confirm against parallel_state.
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # Each reduction is expected to double the values (see the asserts).
    num_rounds = 2 if rank in (0, 1) else 1
    with graph_capture():
        # two tp groups can communicate independently
        for _ in range(num_rounds):
            tensor = tensor_model_parallel_all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == 2**num_rounds, (
            f"rank {rank}: expected mean {2**num_rounds}, got {result}")


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce_with_vllm():
    """Test pynccl with multiple tp groups through vLLM's all-reduce op."""
    # this tests pynccl for multiple tp groups, together with vllm
    # i.e. call `tensor_model_parallel_all_reduce`
    distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)


@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture an all-reduce in a CUDA graph and check it only runs on replay.

    After capture the tensor of ones must be unchanged (mean 1); after one
    ``graph.replay()`` its mean must equal the communicator's ``world_size``.
    """
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                         device=get_world_group().device)
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, stream=pynccl_comm.stream):
            with pynccl_comm.change_state(enable=True):
                # operation during the graph capture is recorded but not executed
                # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
                pynccl_comm.all_reduce(a)
        pynccl_comm.stream.synchronize()
        captured_mean = a.mean().cpu().item()
        # Nothing executed yet: the values are still the initial ones.
        assert captured_mean == pynccl_comm.world_size**0, (
            f"tensor changed during capture: mean {captured_mean}")
        graph.replay()
        pynccl_comm.stream.synchronize()
        replayed_mean = a.mean().cpu().item()
        assert replayed_mean == pynccl_comm.world_size**1, (
            f"expected mean {pynccl_comm.world_size} after replay, "
            f"got {replayed_mean}")


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Run the CUDA-graph all-reduce worker on two ranks."""
    distributed_run(fn=worker_fn_with_cudagraph, world_size=2)


@worker_fn_wrapper
def send_recv_worker_fn():
    """Send a tensor of ones from rank 0 and check every rank ends up with it.

    Rank 0 calls ``send`` on a tensor of ones; every other rank calls ``recv``
    into an uninitialized tensor. Afterwards the mean must be 1 on all ranks.
    The peer rank is left to the communicator's defaults (not visible here).
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    is_sender = pynccl_comm.rank == 0
    # Allocate directly on the rank's GPU instead of staging on the host.
    make_tensor = torch.ones if is_sender else torch.empty
    tensor = make_tensor(16,
                         1024,
                         1024,
                         dtype=torch.float32,
                         device=f"cuda:{pynccl_comm.rank}")
    with pynccl_comm.change_state(enable=True):
        if is_sender:
            pynccl_comm.send(tensor)
        else:
            pynccl_comm.recv(tensor)
    # Wait for the device before reading: the communicator has its own
    # `stream`, which may differ from the default stream.
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    assert result == 1, f"expected mean 1, got {result}"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
    """Run the point-to-point send/recv worker on two ranks."""
    distributed_run(fn=send_recv_worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Run two independent send/recv pairs over groups [0, 2] and [1, 3].

    Rank 0 sends a tensor of ones and rank 1 a tensor of twos; ranks 2 and 3
    receive. Afterwards the mean must be 1 on ranks 0 and 2, and 2 on ranks
    1 and 3.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Both groups are created on every rank: `new_group` has to be entered by
    # all processes, including those that are not members of the group.
    groups = [
        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
        torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
    ]
    in_first_group = rank in (0, 2)
    group = groups[0] if in_first_group else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)

    is_sender = rank in (0, 1)
    # Value carried by this rank's pair: 1 for group [0, 2], 2 for [1, 3].
    expected = 1 if in_first_group else 2
    shape = (16, 1024, 1024)
    if is_sender:
        tensor = torch.full(shape,
                            expected,
                            dtype=torch.float32,
                            device=device)
    else:
        tensor = torch.empty(shape, dtype=torch.float32, device=device)

    with pynccl_comm.change_state(enable=True):
        if is_sender:
            pynccl_comm.send(tensor)
        else:
            pynccl_comm.recv(tensor)
    # Wait for the device before reading: the communicator has its own
    # `stream`, which may differ from the default stream.
    torch.cuda.synchronize()
    result = tensor.mean().cpu().item()
    assert result == expected, (
        f"rank {rank}: expected mean {expected}, got {result}")


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
    """Run the two-pair send/recv worker on four ranks."""
    distributed_run(fn=multiple_send_recv_worker_fn, world_size=4)


def test_ncclGetUniqueId():
    """Check that ``NCCLLibrary.ncclGetUniqueId`` runs and returns an object."""
    lib = NCCLLibrary()
    unique_id = lib.ncclGetUniqueId()
    # `list(unique_id.internal)` is something like this:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # as long as the function doesn't raise an exception, we're good
    assert unique_id is not None