# test_pynccl.py 13.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5

6
import multiprocess as mp
7
import numpy as np
8
9
import pytest
import torch
10
import torch.distributed
11

12
from tests.utils import ensure_current_vllm_config
13
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce  # noqa
14
15
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
16
17
18
19
20
21
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_world_group,
    graph_capture,
    init_distributed_environment,
)
22
from vllm.utils.system_utils import update_environment_variables
23

24
25
# Force the "spawn" start method so every worker begins in a fresh interpreter;
# presumably required because CUDA state cannot be safely reused in a forked
# child. force=True overrides any start method configured earlier.
mp.set_start_method("spawn", force=True)

26
27
28

def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` spawned processes and require all to succeed.

    Each process receives a dict with the torch.distributed rendezvous
    variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``,
    ``MASTER_ADDR``, ``MASTER_PORT``) as its only positional argument; ``fn``
    is responsible for applying them (see ``worker_fn_wrapper``).

    Args:
        fn: process target taking the environment dict.
        world_size: number of processes to launch; process ``i`` gets rank
            ``i``.

    Raises:
        AssertionError: if any process exits with a non-zero exit code. The
            message lists every failing rank together with its exit code.
    """
    processes: list[mp.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        p = mp.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Report every failing rank at once (with its exit code) instead of
    # stopping at the first one, so a crash is attributable to a rank.
    failures = {
        rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failures, (
        f"worker processes exited abnormally (rank -> exitcode): {failures}"
    )

48

49
def worker_fn_wrapper(fn):
    """Turn a zero-argument worker body into a process target taking an env dict.

    ``multiprocessing.Process`` cannot set environment variables for the
    child, so ``distributed_run`` passes them as a plain dict argument. The
    returned callable applies that dict via ``update_environment_variables``,
    selects the CUDA device named by ``LOCAL_RANK``, initialises the
    distributed environment and only then runs ``fn``.
    """

    def wrapped_fn(env):
        update_environment_variables(env)
        cuda_index = os.environ["LOCAL_RANK"]
        torch.cuda.set_device(torch.device(f"cuda:{cuda_index}"))
        init_distributed_environment()
        fn()

    return wrapped_fn
62
63


64
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the world group.

    After the sum every element must equal the world size.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    ones = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    reduced = comm.all_reduce(ones)
    torch.cuda.synchronize()
    assert torch.all(reduced == comm.world_size).cpu().item()
73
74


75
76
77
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl():
    """Basic world-group all-reduce on 2 GPUs."""
    distributed_run(fn=worker_fn, world_size=2)


82
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """Two disjoint 2-rank PyNccl groups all-reduce independently.

    Ranks 0/1 share one group and reduce twice (1 -> 2 -> 4); ranks 2/3
    share the other and reduce once (1 -> 2).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups, in the same order.
    low_pair = torch.distributed.new_group(ranks=[0, 1], backend="gloo")
    high_pair = torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    in_low_pair = rank in [0, 1]
    pynccl_comm = PyNcclCommunicator(
        group=low_pair if in_low_pair else high_pair, device=device
    )
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    # An all-reduce over 2 ranks of a uniform tensor doubles its value.
    num_reductions = 2 if in_low_pair else 1
    for _ in range(num_reductions):
        tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == 2**num_reductions).cpu().item()
102
103


104
105
106
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce():
    """Several TP groups using PyNccl standalone.

    The workers call ``pynccl_comm.all_reduce`` directly, without going
    through vLLM's parallel state.
    """
    distributed_run(fn=multiple_allreduce_worker_fn, world_size=4)
111
112
113


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """Two TP groups all-reduce independently through vLLM's parallel state.

    Assumes ``ensure_model_parallel_initialized(2, 2)`` places ranks 0/1 and
    ranks 2/3 in separate TP groups (the assertions below rely on it). The
    first pair reduces twice (-> 4), the second once (-> 2).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    with ensure_current_vllm_config():
        ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    num_reductions = 2 if rank in [0, 1] else 1
    with graph_capture(device=device):
        for _ in range(num_reductions):
            tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == 2**num_reductions).cpu().item()
130
131


132
133
134
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce_with_vllm():
    """Several TP groups using PyNccl through vLLM.

    The workers call ``tensor_model_parallel_all_reduce`` rather than the
    communicator directly.
    """
    distributed_run(fn=multiple_allreduce_with_vllm_worker_fn, world_size=4)
139
140


141
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a PyNccl all-reduce into a CUDA graph, replay it, and check it."""
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        world = get_world_group()
        comm = PyNcclCommunicator(world.cpu_group, device=world.device)
        # Run something on the default stream first to initialize torch.
        inp = torch.ones((4, 4), device=f"cuda:{comm.rank}")
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            out = comm.all_reduce(inp)
        torch.cuda.synchronize()
        # The captured all-reduce produces `out` only when the graph is replayed.
        graph.replay()
        torch.cuda.synchronize()
        assert torch.all(out == comm.world_size).cpu().item()
157
158


159
160
@worker_fn_wrapper
def all_gather_worker_fn():
    """Each rank contributes a distinct contiguous range of values.

    ``all_gather`` must concatenate the ranges in rank order.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    rank, world_size = comm.rank, comm.world_size
    device = f"cuda:{comm.rank}"
    chunk = 1000

    # Rank r contributes the values r*chunk .. (r+1)*chunk - 1.
    local = torch.arange(chunk, dtype=torch.float32, device=device) + rank * chunk
    gathered = torch.zeros(chunk * world_size, dtype=torch.float32, device=device)
    # Concatenated in rank order, those ranges are exactly 0 .. chunk*world_size-1.
    expected = torch.arange(chunk * world_size, dtype=torch.float32).to(device)

    comm.all_gather(gathered, local)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


187
188
189
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gather():
    """Equal-size all_gather on 2 GPUs."""
    distributed_run(fn=all_gather_worker_fn, world_size=2)


194
195
@worker_fn_wrapper
def all_gatherv_worker_fn():
    """all_gatherv with uneven per-rank sizes.

    Rank r sends ``sizes[r]`` consecutive values starting at ``r * 100``. The
    gathered buffer must hold those pieces back to back in rank order.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    rank = comm.rank
    world_size = comm.world_size
    device = f"cuda:{comm.rank}"

    # Only 8 distinct sizes are defined below.
    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]

    local = torch.arange(sizes[rank], dtype=torch.float32, device=device) + rank * 100
    gathered = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
    pieces = [
        torch.arange(count, dtype=torch.float32) + r * 100
        for r, count in enumerate(sizes)
    ]
    expected = torch.cat(pieces).to(device)

    comm.all_gatherv(gathered, local, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


222
223
224
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gatherv():
    """Variable-size all_gatherv on 2 GPUs."""
    distributed_run(fn=all_gatherv_worker_fn, world_size=2)


229
230
@worker_fn_wrapper
def reduce_scatter_worker_fn():
    """Equal-size reduce_scatter.

    Rank r ends up with the element-wise sum of all inputs, restricted to the
    r-th of ``world_size`` equal slices.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    rank = comm.rank
    world_size = comm.world_size
    device = f"cuda:{comm.rank}"

    num_elems = 1000
    local = (
        torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
    )
    # The input must split evenly across ranks.
    assert num_elems % world_size == 0
    chunk = num_elems // world_size
    result = torch.zeros(chunk, dtype=torch.float32, device=device)

    # Reference: this rank's slice of every rank's input, summed.
    lo, hi = rank * chunk, (rank + 1) * chunk
    base = torch.arange(num_elems, dtype=torch.float32)
    expected = sum(base[lo:hi] + r * num_elems for r in range(world_size)).to(device)

    comm.reduce_scatter(result, local)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


262
263
264
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatter():
    """Equal-size reduce_scatter on 2 GPUs."""
    distributed_run(fn=reduce_scatter_worker_fn, world_size=2)


269
270
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
    """reduce_scatterv with uneven per-rank output sizes.

    Every rank holds ``sum(sizes)`` values offset by ``rank * 100``. After the
    reduction, rank r must hold the summed slice of length ``sizes[r]`` that
    starts right after the slices owned by lower ranks.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    rank = comm.rank
    world_size = comm.world_size
    device = f"cuda:{comm.rank}"

    # Only 8 distinct sizes are defined below.
    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
    num_elems = sum(sizes)
    local = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100
    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)

    # boundaries[r] .. boundaries[r + 1] is the slice owned by rank r.
    boundaries = np.cumsum([0, *sizes])
    start, end = boundaries[rank], boundaries[rank + 1]
    base = torch.arange(num_elems, dtype=torch.float32)
    expected = sum(base[start:end] + r * 100 for r in range(world_size)).to(device)

    comm.reduce_scatterv(result, local, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


300
301
302
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatterv():
    """Variable-size reduce_scatterv on 2 GPUs."""
    distributed_run(fn=reduce_scatterv_worker_fn, world_size=2)


307
308
309
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_with_cudagraph():
    """PyNccl all-reduce captured inside a CUDA graph, on 2 GPUs."""
    distributed_run(fn=worker_fn_with_cudagraph, world_size=2)


314
315
@worker_fn_wrapper
def send_recv_worker_fn():
    """Rank 0 sends a tensor of ones to the next rank, which receives it.

    Rank 0 already holds ones; every other rank starts with an uninitialized
    buffer and receives from the previous rank.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    is_sender = comm.rank == 0
    make = torch.ones if is_sender else torch.empty
    tensor = make(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)

    if is_sender:
        comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
    else:
        comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()
330
331


332
333
334
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_send_recv():
    """Point-to-point send/recv on 2 GPUs."""
    distributed_run(fn=send_recv_worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Two disjoint PyNccl groups perform point-to-point transfers concurrently.

    The groups are {0, 2} and {1, 3}. Global ranks 0 and 1 send and ranks 2
    and 3 receive: rank 2 must end with ones (from rank 0) and rank 3 with
    twos (from rank 1).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups, in the same order.
    even_group = torch.distributed.new_group(ranks=[0, 2], backend="gloo")
    odd_group = torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    in_even_group = rank in [0, 2]
    pynccl_comm = PyNcclCommunicator(
        group=even_group if in_even_group else odd_group, device=device
    )

    shape = (16, 1024, 1024)
    if rank == 0:
        tensor = torch.ones(shape, dtype=torch.float32, device=device)
    elif rank == 1:
        tensor = 2 * torch.ones(shape, dtype=torch.float32, device=device)
    else:
        tensor = torch.empty(shape, dtype=torch.float32, device=device)

    peer_world = pynccl_comm.world_size
    if rank in [0, 1]:
        pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % peer_world)
    else:
        pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % peer_world)
    torch.cuda.synchronize()

    expected = 1 if in_even_group else 2
    assert torch.all(tensor == expected).cpu().item()
363
364


365
366
367
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_send_recv():
    """Concurrent send/recv in two independent groups on 4 GPUs."""
    distributed_run(fn=multiple_send_recv_worker_fn, world_size=4)


372
373
374
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_broadcast():
    """Broadcast from every root rank on 4 GPUs."""
    distributed_run(fn=broadcast_worker_fn, world_size=4)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every root rank in turn; together this is an all-gather.

    Each rank fills its own slot with its rank id and leaves the other slots
    uninitialized. After slot i is broadcast from rank i, every rank's slot i
    must equal i.
    """
    comm = PyNcclCommunicator(
        get_world_group().cpu_group, device=get_world_group().device
    )
    shape = (16, 1024, 1024)
    slots = [
        torch.empty(shape, dtype=torch.float32, device=comm.device)
        for _ in range(comm.world_size)
    ]
    slots[comm.rank] = (
        torch.ones(shape, dtype=torch.float32, device=comm.device) * comm.rank
    )

    for root, slot in enumerate(slots):
        comm.broadcast(slot, src=root)
        # The broadcast may be launched on another stream; wait for it before
        # reading the result.
        torch.cuda.synchronize()
        assert torch.all(slot == root).cpu().item()


403
def test_ncclGetUniqueId():
    """``ncclGetUniqueId`` must succeed and return an id object.

    The id's raw contents (``list(unique_id.internal)``) are opaque, e.g.
    ``[34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    0, 0, ...]`` followed by a long run of zeros, so the test only checks
    that the call does not raise.
    """
    unique_id = NCCLLibrary().ncclGetUniqueId()
    assert unique_id is not None