test_pynccl.py 14.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
import multiprocessing
import os

import numpy as np
import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             get_world_group, graph_capture,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables
20
21
22
23


def distributed_run(fn, world_size):
    """Launch ``fn`` in ``world_size`` processes and require all to succeed.

    Each process receives, as its only positional argument, a dict of
    torch.distributed-style environment variables (rank, local rank, world
    size, master address/port). ``fn`` is responsible for applying them
    itself, e.g. via ``worker_fn_wrapper``.

    Raises:
        AssertionError: if any process exits with a non-zero exit code. The
            message lists every failing rank together with its exit code.
    """
    number_of_processes = world_size
    processes: list[multiprocessing.Process] = []
    for i in range(number_of_processes):
        env: dict[str, str] = {
            'RANK': str(i),
            'LOCAL_RANK': str(i),
            'WORLD_SIZE': str(number_of_processes),
            'LOCAL_WORLD_SIZE': str(number_of_processes),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Collect all failures rather than stopping at the first process, so the
    # report shows every rank that went wrong and how it exited.
    failures = [(rank, p.exitcode) for rank, p in enumerate(processes)
                if p.exitcode != 0]
    assert not failures, (
        f"worker processes failed (rank, exitcode): {failures}")

43

44
def worker_fn_wrapper(fn):
    """Turn a zero-argument test body into a ``multiprocessing`` target.

    `multiprocessing.Process` cannot accept environment variables directly,
    so ``distributed_run`` passes them as an argument. The returned wrapper
    first installs them, then binds the process to ``cuda:{LOCAL_RANK}``,
    initializes the distributed environment and finally calls ``fn()``.

    ``functools.wraps`` copies ``fn``'s ``__module__``/``__qualname__`` onto
    the wrapper. That lets a decorated module-level worker be pickled by
    reference, which the ``spawn``/``forkserver`` start methods require.
    Without it the qualname is ``worker_fn_wrapper.<locals>.wrapped_fn``,
    and pickling fails.
    """

    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn
57
58


59
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the world group with pynccl.

    Every rank contributes ones, so after the reduction each element must
    equal the world size.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    # Allocate directly on this rank's GPU instead of building the 64 MiB
    # tensor on the host and copying it over with `.cuda(rank)`.
    tensor = torch.ones(16,
                        1024,
                        1024,
                        dtype=torch.float32,
                        device=f'cuda:{pynccl_comm.rank}')
    tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
68
69
70
71
72
73
74
75


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Basic pynccl all-reduce across two GPU worker processes."""
    num_gpus = 2
    distributed_run(worker_fn, num_gpus)


76
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """Run pynccl all-reduces in two independent 2-rank groups.

    Ranks {0, 1} and {2, 3} each get their own gloo group and their own
    ``PyNcclCommunicator``. Group {0, 1} reduces twice (1 -> 2 -> 4), while
    group {2, 3} reduces once (1 -> 2). This shows the groups progress
    independently.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups: `new_group` must be entered by all
    # processes, even those that are not members of the new group.
    groups = [
        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    ]
    in_first_group = rank in [0, 1]
    group = groups[0] if in_first_group else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # two groups can communicate independently
    if in_first_group:
        tensor = pynccl_comm.all_reduce(tensor)
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == 4).cpu().item()
    else:
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == 2).cpu().item()
96
97
98


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce():
    """Pynccl with multiple tp groups, used standalone.

    The worker calls `pynccl_comm.all_reduce` directly rather than going
    through vLLM's tensor-parallel helpers.
    """
    distributed_run(multiple_allreduce_worker_fn, 4)
104
105
106


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """All-reduce through `tensor_model_parallel_all_reduce` under capture.

    The assertions expect ranks 0 and 1 to share one 2-rank tensor-parallel
    group and ranks 2 and 3 to share the other. That presumably follows
    from ``ensure_model_parallel_initialized(2, 2)`` with 4 ranks; confirm
    the argument order in `parallel_state` if this changes. The first group
    reduces twice (1 -> 2 -> 4), the second once (1 -> 2).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with graph_capture(device=device):
        # two tp groups can communicate independently
        if rank in [0, 1]:
            tensor = tensor_model_parallel_all_reduce(tensor)
            tensor = tensor_model_parallel_all_reduce(tensor)
            torch.cuda.synchronize()
            assert torch.all(tensor == 4).cpu().item()
        else:
            tensor = tensor_model_parallel_all_reduce(tensor)
            torch.cuda.synchronize()
            assert torch.all(tensor == 2).cpu().item()
122
123
124
125


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce_with_vllm():
    """Pynccl with multiple tp groups, used together with vLLM.

    The worker calls `tensor_model_parallel_all_reduce` instead of the
    communicator directly.
    """
    distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)
130
131


132
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a pynccl all-reduce in a CUDA graph, then replay it once.

    After replay, every element of the output must equal the world size,
    because every rank contributed ones.
    """
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                         device=get_world_group().device)
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            a_out = pynccl_comm.all_reduce(a)
        torch.cuda.synchronize()
        # Capture only records the work; `a_out` is populated by the replay.
        graph.replay()
        torch.cuda.synchronize()
        assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
147
148


149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
@worker_fn_wrapper
def all_gather_worker_fn():
    """Check pynccl all_gather with equal-sized contributions.

    Rank ``r`` contributes ``num_elems`` consecutive values starting at
    ``r * num_elems``. The gathered buffer must be their concatenation in
    rank order.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    result = torch.zeros(num_elems * world_size,
                         dtype=torch.float32,
                         device=device)

    expected = torch.cat([
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]).to(device)

    pynccl_comm.all_gather(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    """Equal-size all_gather across two GPU worker processes."""
    num_gpus = 2
    distributed_run(all_gather_worker_fn, num_gpus)


181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
@worker_fn_wrapper
def all_gatherv_worker_fn():
    """Check all_gatherv when each rank contributes a different length.

    Rank ``r`` contributes ``sizes[r]`` consecutive values starting at
    ``r * 100``. The gathered buffer must be their concatenation in rank
    order.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    my_rank, n_ranks = pynccl_comm.rank, pynccl_comm.world_size
    dev = f'cuda:{pynccl_comm.rank}'

    # The fixed size table below only covers up to 8 ranks.
    assert n_ranks <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:n_ranks]

    local = torch.arange(sizes[my_rank], dtype=torch.float32,
                         device=dev) + my_rank * 100
    gathered = torch.zeros(sum(sizes), dtype=torch.float32, device=dev)

    reference_chunks = []
    for r in range(n_ranks):
        reference_chunks.append(
            torch.arange(sizes[r], dtype=torch.float32) + r * 100)
    expected = torch.cat(reference_chunks).to(dev)

    pynccl_comm.all_gatherv(gathered, local, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
    """Variable-size all_gatherv across two GPU worker processes."""
    num_gpus = 2
    distributed_run(all_gatherv_worker_fn, num_gpus)


213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
@worker_fn_wrapper
def reduce_scatter_worker_fn():
    """Check pynccl reduce_scatter with equal-sized output chunks.

    Every rank contributes ``num_elems`` values. Afterwards rank ``r`` must
    hold the element-wise sum, over all ranks, of chunk ``r``. Each chunk
    has ``num_elems // world_size`` elements.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    # Equal-sized chunks require the input to split evenly across ranks.
    assert num_elems % world_size == 0
    scattered_size = num_elems // world_size
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    result = torch.zeros(scattered_size, dtype=torch.float32, device=device)

    # Calculate expected result for this rank's chunk. The loop variable is
    # `t`, not `tensor`, so it does not read as rebinding the input above.
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]
    chunk = slice(rank * scattered_size, (rank + 1) * scattered_size)
    expected = sum(t[chunk] for t in all_tensors).to(device)

    pynccl_comm.reduce_scatter(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    """Equal-size reduce_scatter across two GPU worker processes."""
    num_gpus = 2
    distributed_run(reduce_scatter_worker_fn, num_gpus)


250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
    """Check reduce_scatterv when output chunks have different lengths.

    Every rank contributes ``sum(sizes)`` values offset by ``rank * 100``.
    Rank ``r`` must end up with the element-wise sum, over all ranks, of
    the ``sizes[r]``-long segment that starts after the segments of ranks
    ``0 .. r-1``.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    my_rank = pynccl_comm.rank
    n_ranks = pynccl_comm.world_size
    dev = f'cuda:{pynccl_comm.rank}'

    # The fixed size table below only covers up to 8 ranks.
    assert n_ranks <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:n_ranks]
    total = sum(sizes)
    send_buf = torch.arange(total, dtype=torch.float32,
                            device=dev) + my_rank * 100
    recv_buf = torch.zeros(sizes[my_rank], dtype=torch.float32, device=dev)

    # This rank's segment is bounded by the running totals of `sizes`.
    contributions = [
        torch.arange(total, dtype=torch.float32) + r * 100
        for r in range(n_ranks)
    ]
    offsets = np.cumsum(sizes)
    lo = offsets[my_rank - 1] if my_rank != 0 else 0
    hi = offsets[my_rank]
    expected = sum(c[lo:hi] for c in contributions).to(dev)

    pynccl_comm.reduce_scatterv(recv_buf, send_buf, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(recv_buf, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
    """Variable-size reduce_scatterv across two GPU worker processes."""
    num_gpus = 2
    distributed_run(reduce_scatterv_worker_fn, num_gpus)


287
288
289
290
291
292
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """CUDA-graph-captured pynccl all-reduce on two GPU worker processes."""
    num_gpus = 2
    distributed_run(worker_fn_with_cudagraph, num_gpus)


293
294
@worker_fn_wrapper
def send_recv_worker_fn():
    """Point-to-point send/recv with pynccl.

    Rank 0 sends a tensor of ones to ``(0 + 1) % world_size``. Every other
    rank receives from ``(rank - 1) % world_size``. With two ranks this is
    a single 0 -> 1 transfer, after which both ranks hold ones.

    NOTE(review): only rank 0 ever sends. With world_size > 2, ranks >= 2
    would wait on a send that never happens. The test runs this with 2
    ranks.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    # Allocate directly on this rank's GPU instead of building the tensor on
    # the host and copying it over with `.cuda(rank)`.
    device = f'cuda:{pynccl_comm.rank}'
    if pynccl_comm.rank == 0:
        tensor = torch.ones(16,
                            1024,
                            1024,
                            dtype=torch.float32,
                            device=device)
    else:
        tensor = torch.empty(16,
                             1024,
                             1024,
                             dtype=torch.float32,
                             device=device)

    if pynccl_comm.rank == 0:
        pynccl_comm.send(tensor,
                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
    else:
        pynccl_comm.recv(tensor,
                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
    """Point-to-point send/recv between two GPU worker processes."""
    num_gpus = 2
    distributed_run(send_recv_worker_fn, num_gpus)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Run send/recv concurrently in two independent 2-rank pynccl groups.

    The groups are {0, 2} and {1, 3}. Global rank 0 sends ones and global
    rank 1 sends twos; ranks 2 and 3 receive. Afterwards ranks 0 and 2 must
    hold ones, and ranks 1 and 3 must hold twos.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups: `new_group` must be entered by all
    # processes, even those that are not members of the new group.
    groups = [
        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
        torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    ]
    group = groups[0] if rank in [0, 2] else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    shape = (16, 1024, 1024)
    if rank == 0:
        tensor = torch.ones(shape, dtype=torch.float32, device=device)
    elif rank == 1:
        tensor = 2 * torch.ones(shape, dtype=torch.float32, device=device)
    else:
        tensor = torch.empty(shape, dtype=torch.float32, device=device)
    # src/dst are computed from `pynccl_comm.rank`. That is presumably the
    # rank within `group`, not the global rank; TODO confirm.
    if rank in [0, 1]:
        pynccl_comm.send(tensor,
                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
    else:
        pynccl_comm.recv(tensor,
                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
    torch.cuda.synchronize()
    if rank in [0, 2]:
        assert torch.all(tensor == 1).cpu().item()
    else:
        assert torch.all(tensor == 2).cpu().item()
351
352
353
354
355
356
357
358


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
    """Send/recv in two independent pynccl groups on four GPUs."""
    num_gpus = 4
    distributed_run(multiple_send_recv_worker_fn, num_gpus)


359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    """Broadcast from every root rank across four GPU worker processes."""
    num_gpus = 4
    distributed_run(broadcast_worker_fn, num_gpus)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Test broadcast for every root rank.

    Each rank fills its own slot with its rank id. Each slot is then
    broadcast from its owning rank, so afterwards slot ``i`` holds ``i`` on
    every rank. Essentially this is an all-gather built from broadcasts.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    shape = (16, 1024, 1024)
    recv_tensors = [
        torch.empty(*shape, dtype=torch.float32, device=pynccl_comm.device)
        for _ in range(pynccl_comm.world_size)
    ]
    own_slot = torch.ones(*shape,
                          dtype=torch.float32,
                          device=pynccl_comm.device)
    recv_tensors[pynccl_comm.rank] = own_slot * pynccl_comm.rank

    for root, buf in enumerate(recv_tensors):
        pynccl_comm.broadcast(buf, src=root)
        # the broadcast op might be launched in a different stream
        # need to synchronize to make sure the tensor is ready
        torch.cuda.synchronize()
        assert torch.all(buf == root).cpu().item()


391
def test_ncclGetUniqueId():
    """Smoke-test ``NCCLLibrary.ncclGetUniqueId``.

    Only checks that the call succeeds and returns an object. The id's
    contents are not inspected.
    """
    lib = NCCLLibrary()
    unique_id = lib.ncclGetUniqueId()
    # `list(unique_id.internal)` is something like this:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # as long as the function doesn't raise an exception, we're good
    assert unique_id is not None