test_pynccl.py 12.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
import multiprocessing
import os
import socket

import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             get_world_group, graph_capture,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    number_of_processes = world_size
22
    processes: list[multiprocessing.Process] = []
23
    for i in range(number_of_processes):
24
        env: dict[str, str] = {}
25
        env['RANK'] = str(i)
26
        env['LOCAL_RANK'] = str(i)
27
        env['WORLD_SIZE'] = str(number_of_processes)
28
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
29
30
31
32
33
34
35
36
37
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

38
39
40
    for p in processes:
        assert p.exitcode == 0

41

42
def worker_fn_wrapper(fn):
    """Adapt a zero-argument worker ``fn`` into a ``distributed_run`` target.

    ``multiprocessing.Process`` cannot be given per-process environment
    variables directly, so ``distributed_run`` passes them as a dict
    argument. The returned wrapper does the following, in order:

    1. Applies that dict via ``update_environment_variables``.
    2. Sets the current CUDA device to ``cuda:<LOCAL_RANK>``.
    3. Calls ``init_distributed_environment()``.
    4. Runs ``fn()``.

    ``functools.wraps`` copies ``fn``'s ``__name__``/``__qualname__``/
    ``__module__``/``__doc__`` onto the wrapper. The decorated result is
    bound to that same module-level name, so the wrapper can be pickled by
    reference. That is needed when ``multiprocessing`` uses the ``spawn`` or
    ``forkserver`` start methods; a bare nested function has a
    ``<locals>`` qualname and cannot be pickled.
    """

    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn
55
56


57
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the whole world group with pynccl.

    Every rank contributes ones, and the check expects each element of the
    result to equal the world size.
    """
    comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                              device=get_world_group().device)
    ones = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    reduced = comm.all_reduce(ones)
    torch.cuda.synchronize()
    assert torch.all(reduced == comm.world_size).cpu().item()
66
67
68
69
70
71
72
73


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Single pynccl all-reduce across a two-process world group."""
    distributed_run(fn=worker_fn, world_size=2)


74
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """Run all-reduces on two disjoint two-rank groups at the same time.

    Ranks 0 and 1 share one gloo-backed group and reduce twice, so ones
    become 4. Ranks 2 and 3 share the other group and reduce once, so ones
    become 2. This shows the two communicators operate independently.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups, in the same order, including the one
    # it does not belong to.
    low_pair = torch.distributed.new_group(ranks=[0, 1], backend="gloo")
    high_pair = torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    in_low_pair = rank in [0, 1]
    pynccl_comm = PyNcclCommunicator(group=low_pair if in_low_pair else high_pair,
                                     device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    # two groups can communicate independently
    num_reductions, expected = (2, 4) if in_low_pair else (1, 2)
    for _ in range(num_reductions):
        tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == expected).cpu().item()
94
95
96


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce():
    """Exercise pynccl on several tensor-parallel groups in isolation.

    The worker calls ``pynccl_comm.all_reduce`` directly instead of going
    through vLLM's tensor-parallel helpers.
    """
    distributed_run(fn=multiple_allreduce_worker_fn, world_size=4)
102
103
104


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """Two-group all-reduce scenario, routed through vLLM's TP helpers.

    Groups are set up with ``ensure_model_parallel_initialized(2, 2)``
    (presumably TP size 2 and PP size 2; confirm against parallel_state).
    Reductions use ``tensor_model_parallel_all_reduce`` inside a
    ``graph_capture`` context. Ranks 0 and 1 reduce twice and expect 4;
    the other ranks reduce once and expect 2.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with graph_capture(device=device):
        # two tp groups can communicate independently
        rounds, expected = (2, 4) if rank in [0, 1] else (1, 2)
        for _ in range(rounds):
            tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == expected).cpu().item()
120
121
122
123


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce_with_vllm():
    """Exercise pynccl on several tensor-parallel groups via vLLM.

    Unlike the standalone variant, the worker reduces through
    ``tensor_model_parallel_all_reduce``.
    """
    distributed_run(fn=multiple_allreduce_with_vllm_worker_fn, world_size=4)
128
129


130
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a pynccl all-reduce in a CUDA graph, replay it, check it.

    Each rank contributes a 4x4 tensor of ones. The captured output is
    checked only after ``replay()``, and must equal the world size.
    """
    with torch.no_grad():
        cuda_graph = torch.cuda.CUDAGraph()
        comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                                  device=get_world_group().device)
        # run something in the default stream to initialize torch engine
        ones = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(cuda_graph):
            reduced = comm.all_reduce(ones)
        torch.cuda.synchronize()
        cuda_graph.replay()
        torch.cuda.synchronize()
        assert torch.all(reduced == comm.world_size).cpu().item()
145
146


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
@worker_fn_wrapper
def all_gather_worker_fn():
    """Check that pynccl ``all_gather`` concatenates chunks in rank order.

    Rank ``r`` contributes ``arange(num_elems) + r * num_elems``. The
    expected gathered buffer is every rank's chunk concatenated by rank.
    """
    comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                              device=get_world_group().device)
    rank, world_size = comm.rank, comm.world_size
    device = f'cuda:{comm.rank}'
    num_elems = 1000

    send_buf = torch.arange(num_elems, dtype=torch.float32,
                            device=device) + rank * num_elems
    recv_buf = torch.zeros(num_elems * world_size,
                           dtype=torch.float32,
                           device=device)

    # Reference result built on the CPU, then moved to this rank's GPU.
    per_rank_chunks = [
        torch.arange(num_elems, dtype=torch.float32) + peer * num_elems
        for peer in range(world_size)
    ]
    expected = torch.cat(per_rank_chunks).to(device)

    comm.all_gather(recv_buf, send_buf)
    torch.cuda.synchronize()
    torch.testing.assert_close(recv_buf, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    """pynccl ``all_gather`` across two processes."""
    distributed_run(fn=all_gather_worker_fn, world_size=2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    """Check pynccl ``reduce_scatter``: summed inputs, one slice per rank.

    Rank ``r`` contributes ``arange(num_elems) + r * num_elems``. Rank
    ``rank`` should receive the element-wise sum over all ranks of the
    slice ``[rank * chunk, (rank + 1) * chunk)``.
    """
    comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                              device=get_world_group().device)
    rank = comm.rank
    world_size = comm.world_size
    device = f'cuda:{comm.rank}'

    num_elems = 1000
    send_buf = torch.arange(num_elems, dtype=torch.float32,
                            device=device) + rank * num_elems
    assert num_elems % world_size == 0
    recv_buf = torch.zeros(num_elems // world_size,
                           dtype=torch.float32,
                           device=device)

    # Reference: sum this rank's slice of every rank's input on the CPU.
    chunk = num_elems // world_size
    start, stop = rank * chunk, (rank + 1) * chunk
    inputs = [
        torch.arange(num_elems, dtype=torch.float32) + peer * num_elems
        for peer in range(world_size)
    ]
    expected = sum(src[start:stop] for src in inputs).to(device)

    comm.reduce_scatter(recv_buf, send_buf)
    torch.cuda.synchronize()
    torch.testing.assert_close(recv_buf, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    """pynccl ``reduce_scatter`` across two processes."""
    distributed_run(fn=reduce_scatter_worker_fn, world_size=2)


216
217
218
219
220
221
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """pynccl all-reduce captured and replayed in a CUDA graph, 2 processes."""
    distributed_run(fn=worker_fn_with_cudagraph, world_size=2)


222
223
@worker_fn_wrapper
def send_recv_worker_fn():
    """Point-to-point check: rank 0 sends ones, the next rank receives them.

    Rank 0 sends to ``(rank + 1) % world_size``. Every other rank posts a
    recv from ``(rank - 1) % world_size`` into an uninitialized buffer and
    must end up holding all ones.
    """
    comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                              device=get_world_group().device)
    is_sender = comm.rank == 0
    make_buffer = torch.ones if is_sender else torch.empty
    tensor = make_buffer(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)

    if is_sender:
        comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
    else:
        comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
    """pynccl send/recv between two processes."""
    distributed_run(fn=send_recv_worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Two independent send/recv pairs running concurrently.

    Groups are {0, 2} and {1, 3}. Global ranks 0 and 1 send and ranks 2 and
    3 receive. Group {0, 2} carries ones and group {1, 3} carries twos.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups, in the same order.
    even_group = torch.distributed.new_group(ranks=[0, 2], backend="gloo")
    odd_group = torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    in_even_group = rank in [0, 2]
    comm = PyNcclCommunicator(group=even_group if in_even_group else odd_group,
                              device=device)

    shape = (16, 1024, 1024)
    if rank == 0:
        tensor = torch.ones(*shape, dtype=torch.float32, device=device)
    elif rank == 1:
        tensor = 2 * torch.ones(*shape, dtype=torch.float32, device=device)
    else:
        tensor = torch.empty(*shape, dtype=torch.float32, device=device)

    if rank in [0, 1]:
        comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
    else:
        comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)
    torch.cuda.synchronize()

    expected_value = 1 if in_even_group else 2
    assert torch.all(tensor == expected_value).cpu().item()
280
281
282
283
284
285
286
287


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
    """Two concurrent pynccl send/recv pairs across four processes."""
    distributed_run(fn=multiple_send_recv_worker_fn, world_size=4)


288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    """pynccl broadcast from every root in turn, across four processes."""
    distributed_run(fn=broadcast_worker_fn, world_size=4)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every root rank in turn (together, an all-gather).

    Each rank owns slot ``rank`` and fills it with its rank id. Every slot
    is then broadcast from its owner, so afterwards slot ``i`` must hold
    ``i`` everywhere.
    """
    comm = PyNcclCommunicator(group=get_world_group().cpu_group,
                              device=get_world_group().device)
    shape = (16, 1024, 1024)
    slots = [
        torch.empty(*shape, dtype=torch.float32, device=comm.device)
        for _ in range(comm.world_size)
    ]
    slots[comm.rank] = torch.ones(
        *shape, dtype=torch.float32, device=comm.device) * comm.rank

    for root, buf in enumerate(slots):
        comm.broadcast(buf, src=root)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.cuda.synchronize()
        assert torch.all(buf == root).cpu().item()


320
def test_ncclGetUniqueId():
    """``NCCLLibrary.ncclGetUniqueId`` should return without raising.

    The ID contents are opaque; `list(unique_id.internal)` looks something
    like:
    [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    0, 0, ..., 0]
    so the only check is that an ID object comes back.
    """
    unique_id = NCCLLibrary().ncclGetUniqueId()
    assert unique_id is not None