test_pynccl.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import multiprocessing
5
import os
6

7
import numpy as np
8
9
import pytest
import torch
10
import torch.distributed
11

12
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce  # noqa
13
14
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
15
16
17
18
19
20
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized,
    get_world_group,
    graph_capture,
    init_distributed_environment,
)
21
from vllm.utils import update_environment_variables
22
23
24
25


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require all to succeed.

    Each child is started with a single positional argument: a dict of
    torch.distributed-style environment variables (``RANK``, ``LOCAL_RANK``,
    ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``).
    ``fn`` is responsible for applying them, e.g. via ``worker_fn_wrapper``.

    Args:
        fn: Callable invoked as ``fn(env)`` in every child. It must be
            picklable when a non-``fork`` multiprocessing start method is used.
        world_size: Number of processes to launch. Rank ``i`` is also used as
            the local rank, so all ranks are assumed to share one node.

    Raises:
        AssertionError: If any child exits with a non-zero code. The message
            names the first failing rank and its exit code.
    """
    processes: list[multiprocessing.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            # Fixed rendezvous port; presumably two runs of this helper on
            # the same host at the same time would collide on it.
            "MASTER_PORT": "12345",
        }
        p = multiprocessing.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    # Wait for every rank before inspecting exit codes, so no child is still
    # running when an assertion below fires.
    for p in processes:
        p.join()

    for rank, p in enumerate(processes):
        assert p.exitcode == 0, (
            f"process for rank {rank} exited with code {p.exitcode}"
        )

45

46
def worker_fn_wrapper(fn):
    """Adapt a zero-argument worker into a ``multiprocessing`` target.

    `multiprocessing.Process` cannot accept environment variables directly,
    so ``distributed_run`` passes them as a dict argument instead. The
    returned wrapper does the following, in order:

    1. applies that dict to ``os.environ``;
    2. binds the process to ``cuda:{LOCAL_RANK}``;
    3. calls ``init_distributed_environment()``;
    4. runs ``fn()``.

    ``functools.wraps`` copies ``fn``'s ``__module__``/``__qualname__`` onto
    the wrapper. When this is used as a decorator, as throughout this file,
    the module-level name is rebound to the wrapper, so pickle can find it by
    name. The ``spawn`` and ``forkserver`` start methods rely on this. A bare
    nested function is not picklable.
    """
    import functools

    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ["LOCAL_RANK"]
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn
59
60


61
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the world group using pynccl.

    Afterwards every element must equal the world size.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    ones = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    reduced = comm.all_reduce(ones)
    torch.cuda.synchronize()
    assert torch.all(reduced == comm.world_size).cpu().item()
70
71


72
73
74
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl():
    """Single pynccl all-reduce across a 2-process world group."""
    distributed_run(worker_fn, world_size=2)


79
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """Run independent all-reduces on two disjoint 2-rank pynccl groups.

    Ranks {0, 1} and {2, 3} each get their own gloo group wrapped in a
    PyNcclCommunicator. The first pair reduces twice (1 -> 2 -> 4) and the
    second pair once (1 -> 2). Each pair must see only its own traffic.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # new_group is collective: every rank must create every group, in the
    # same order, even groups it is not a member of.
    first_pair = torch.distributed.new_group(ranks=[0, 1], backend="gloo")
    second_pair = torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    in_first_pair = rank in [0, 1]
    group = first_pair if in_first_pair else second_pair
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)

    # two groups can communicate independently
    if in_first_pair:
        for _ in range(2):
            tensor = pynccl_comm.all_reduce(tensor)
        expected = 4
    else:
        tensor = pynccl_comm.all_reduce(tensor)
        expected = 2
    torch.cuda.synchronize()
    assert torch.all(tensor == expected).cpu().item()
99
100


101
102
103
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce():
    """Exercise pynccl with several TP groups in a standalone way.

    The worker calls ``pynccl_comm.all_reduce`` directly, without going
    through vLLM's parallel-state helpers.
    """
    distributed_run(multiple_allreduce_worker_fn, world_size=4)
108
109
110


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """Two-group all-reduce check routed through vLLM's TP all-reduce.

    ``ensure_model_parallel_initialized(2, 2)`` presumably sets TP size 2 and
    PP size 2. That yields the two TP groups {0, 1} and {2, 3} the assertions
    assume. The collectives and checks run inside ``graph_capture``.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # Ranks 0/1 reduce twice (1 -> 2 -> 4); ranks 2/3 reduce once (1 -> 2).
    num_reductions = 2 if rank in [0, 1] else 1
    with graph_capture(device=device):
        # two tp groups can communicate independently
        for _ in range(num_reductions):
            tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == 2**num_reductions).cpu().item()
126
127


128
129
130
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_allreduce_with_vllm():
    """Exercise pynccl with several TP groups, together with vLLM.

    The worker goes through ``tensor_model_parallel_all_reduce`` rather than
    calling the communicator directly.
    """
    distributed_run(multiple_allreduce_with_vllm_worker_fn, world_size=4)
135
136


137
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a pynccl all-reduce in a CUDA graph, replay it, check the sum."""
    with torch.no_grad():
        captured = torch.cuda.CUDAGraph()
        world = get_world_group()
        comm = PyNcclCommunicator(world.cpu_group, device=world.device)

        # run something in the default stream to initialize torch engine
        inp = torch.ones((4, 4), device=f"cuda:{comm.rank}")
        torch.cuda.synchronize()

        # Capture only records the work; `out` is filled in by replay().
        with torch.cuda.graph(captured):
            out = comm.all_reduce(inp)
        torch.cuda.synchronize()

        captured.replay()
        torch.cuda.synchronize()
        assert torch.all(out == comm.world_size).cpu().item()
153
154


155
156
@worker_fn_wrapper
def all_gather_worker_fn():
    """Check pynccl all_gather of equal-sized 1-D chunks.

    Rank ``r`` contributes ``arange(1000) + r * 1000``. Concatenated in rank
    order, the chunks form the contiguous range ``arange(1000 * world_size)``.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    rank, world_size = comm.rank, comm.world_size
    device = f"cuda:{rank}"

    chunk = 1000
    local = torch.arange(chunk, dtype=torch.float32, device=device) + rank * chunk
    gathered = torch.zeros(chunk * world_size, dtype=torch.float32, device=device)
    expected = torch.arange(chunk * world_size, dtype=torch.float32).to(device)

    comm.all_gather(gathered, local)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


183
184
185
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gather():
    """Equal-sized all_gather across 2 processes."""
    distributed_run(all_gather_worker_fn, world_size=2)


190
191
@worker_fn_wrapper
def all_gatherv_worker_fn():
    """Check pynccl all_gatherv with a different chunk length per rank.

    Rank ``r`` contributes ``arange(sizes[r]) + r * 100``. The gathered
    buffer must be those chunks concatenated in rank order.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    rank, world_size = comm.rank, comm.world_size
    device = f"cuda:{rank}"

    # The hard-coded size table only covers up to 8 ranks.
    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]

    def chunk_for(r, **kwargs):
        # Rank r's contribution; built on CPU unless a device kwarg is given.
        return torch.arange(sizes[r], dtype=torch.float32, **kwargs) + r * 100

    local = chunk_for(rank, device=device)
    gathered = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
    expected = torch.cat([chunk_for(r) for r in range(world_size)]).to(device)

    comm.all_gatherv(gathered, local, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


218
219
220
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_all_gatherv():
    """Variable-sized all_gatherv across 2 processes."""
    distributed_run(all_gatherv_worker_fn, world_size=2)


225
226
@worker_fn_wrapper
def reduce_scatter_worker_fn():
    """Check pynccl reduce_scatter with equal-sized output chunks.

    Every rank holds ``arange(1000) + rank * 1000``. Rank ``r`` must receive
    the element-wise sum, over all ranks, of slice
    ``[r * chunk, (r + 1) * chunk)``, where ``chunk = 1000 // world_size``.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    rank, world_size = comm.rank, comm.world_size
    device = f"cuda:{rank}"

    num_elems = 1000
    local = (
        torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems
    )
    assert num_elems % world_size == 0
    chunk = num_elems // world_size
    result = torch.zeros(chunk, dtype=torch.float32, device=device)

    # CPU reference: sum this rank's slice of every rank's input.
    lo, hi = rank * chunk, (rank + 1) * chunk
    expected = sum(
        (torch.arange(num_elems, dtype=torch.float32) + r * num_elems)[lo:hi]
        for r in range(world_size)
    ).to(device)

    comm.reduce_scatter(result, local)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


258
259
260
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatter():
    """Equal-sized reduce_scatter across 2 processes."""
    distributed_run(reduce_scatter_worker_fn, world_size=2)


265
266
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
    """Check pynccl reduce_scatterv with a different output length per rank.

    Every rank holds ``arange(sum(sizes)) + rank * 100``. Rank ``r`` must
    receive the element-wise sum, over all ranks, of its own
    ``sizes[r]``-long segment. Segments are laid out back to back in rank
    order.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    rank, world_size = comm.rank, comm.world_size
    device = f"cuda:{rank}"

    # The hard-coded size table only covers up to 8 ranks.
    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
    num_elems = sum(sizes)

    def full_input(r, **kwargs):
        # Rank r's whole input; built on CPU unless a device kwarg is given.
        return torch.arange(num_elems, dtype=torch.float32, **kwargs) + r * 100

    local = full_input(rank, device=device)
    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)

    # offsets[r]:offsets[r + 1] is rank r's segment of the flat input.
    offsets = np.concatenate(([0], np.cumsum(sizes)))
    start, end = offsets[rank], offsets[rank + 1]
    expected = sum(full_input(r)[start:end] for r in range(world_size)).to(device)

    comm.reduce_scatterv(result, local, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


296
297
298
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_reduce_scatterv():
    """Variable-sized reduce_scatterv across 2 processes."""
    distributed_run(reduce_scatterv_worker_fn, world_size=2)


303
304
305
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_with_cudagraph():
    """pynccl all-reduce captured in and replayed from a CUDA graph."""
    distributed_run(worker_fn_with_cudagraph, world_size=2)


310
311
@worker_fn_wrapper
def send_recv_worker_fn():
    """Point-to-point check: rank 0 sends a tensor of ones to the next rank.

    The receiver's buffer starts uninitialized (``torch.empty``), so the
    all-ones assertion is effectively only satisfied by the received payload.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    is_sender = comm.rank == 0
    alloc = torch.ones if is_sender else torch.empty
    tensor = alloc(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)

    if is_sender:
        comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
    else:
        comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()
326
327


328
329
330
@pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test."
)
def test_pynccl_send_recv():
    """Single send/recv pair across 2 processes."""
    distributed_run(send_recv_worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Concurrent send/recv on two disjoint pynccl groups {0, 2} and {1, 3}.

    Global rank 0 sends ones and rank 1 sends twos. Ranks 2 and 3 receive,
    and each must hold exactly its own group's payload.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # new_group is collective: every rank creates both groups, in order.
    even_group = torch.distributed.new_group(ranks=[0, 2], backend="gloo")
    odd_group = torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    group = even_group if rank in [0, 2] else odd_group
    pynccl_comm = PyNcclCommunicator(group=group, device=device)

    shape = (16, 1024, 1024)
    if rank == 0:
        tensor = torch.ones(*shape, dtype=torch.float32, device=device)
    elif rank == 1:
        tensor = 2 * torch.ones(*shape, dtype=torch.float32, device=device)
    else:
        tensor = torch.empty(*shape, dtype=torch.float32, device=device)

    peer_rank, group_size = pynccl_comm.rank, pynccl_comm.world_size
    if rank in [0, 1]:
        pynccl_comm.send(tensor, dst=(peer_rank + 1) % group_size)
    else:
        pynccl_comm.recv(tensor, src=(peer_rank - 1) % group_size)
    torch.cuda.synchronize()

    expected = 1 if rank in [0, 2] else 2
    assert torch.all(tensor == expected).cpu().item()
359
360


361
362
363
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_multiple_send_recv():
    """Two independent send/recv pairs across 4 processes."""
    distributed_run(multiple_send_recv_worker_fn, world_size=4)


368
369
370
@pytest.mark.skipif(
    torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test."
)
def test_pynccl_broadcast():
    """Broadcast from every root rank across 4 processes."""
    distributed_run(broadcast_worker_fn, world_size=4)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every root rank in turn; in effect an all-gather.

    Rank ``i`` owns ``buffers[i]``, filled with the value ``i``. After the
    broadcast rooted at ``i``, every rank must hold that value in
    ``buffers[i]``.
    """
    world = get_world_group()
    comm = PyNcclCommunicator(world.cpu_group, device=world.device)
    shape = (16, 1024, 1024)
    buffers = [
        torch.empty(*shape, dtype=torch.float32, device=comm.device)
        for _ in range(comm.world_size)
    ]
    buffers[comm.rank] = (
        torch.ones(*shape, dtype=torch.float32, device=comm.device) * comm.rank
    )

    for root, buf in enumerate(buffers):
        comm.broadcast(buf, src=root)
        # The broadcast may be launched on a different stream; synchronize
        # so the tensor is ready before it is compared.
        torch.cuda.synchronize()
        assert torch.all(buf == root).cpu().item()


399
def test_ncclGetUniqueId():
    """Smoke test: NCCLLibrary.ncclGetUniqueId() returns without raising."""
    unique_id = NCCLLibrary().ncclGetUniqueId()
    # `list(unique_id.internal)` looks like a signed-byte array, e.g.
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    # 0, 0, ...] followed by a long run of zeros. The contents are not
    # checked: returning without an exception is the pass criterion.
    assert unique_id is not None