# test_pynccl.py 12.2 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import multiprocessing
import os
from typing import Dict, List

import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             get_world_group, graph_capture,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables
19
20
21
22


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require all to succeed.

    Every child is started with ``fn`` as its target and one positional
    argument: a dict of environment-variable strings. The ``i``-th child
    gets ``RANK`` and ``LOCAL_RANK`` set to ``str(i)``; ``WORLD_SIZE`` and
    ``LOCAL_WORLD_SIZE`` are ``str(world_size)``; ``MASTER_ADDR`` is
    ``'localhost'`` and ``MASTER_PORT`` is ``'12345'``.

    Args:
        fn: Callable executed in each child as ``fn(env)``.
        world_size: Number of child processes to start.

    Raises:
        AssertionError: If any child finishes with a non-zero exit code.
    """
    number_of_processes = world_size
    processes: List[multiprocessing.Process] = []
    for i in range(number_of_processes):
        env: Dict[str, str] = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    # Wait for every child before checking any exit code, so one failing
    # rank does not leave the others un-joined.
    for p in processes:
        p.join()

    for rank, p in enumerate(processes):
        assert p.exitcode == 0, (
            f"worker process for rank {rank} exited with code {p.exitcode}")

42

43
def worker_fn_wrapper(fn):
    """Wrap a zero-argument worker so it can be a ``distributed_run`` target.

    `multiprocessing.Process` cannot accept environment variables directly,
    so the environment is passed as an argument and applied inside the
    child process before the worker body runs.

    Args:
        fn: Zero-argument callable containing the actual test body.

    Returns:
        A function taking one ``env`` mapping. It applies ``env`` with
        ``update_environment_variables``, selects the CUDA device named by
        the ``LOCAL_RANK`` environment variable, calls
        ``init_distributed_environment()`` and finally calls ``fn()``.
    """

    def wrapped_fn(env):
        update_environment_variables(env)
        # LOCAL_RANK is read back from the process environment after the
        # update above and used directly as the CUDA device index.
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn
56
57


58
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones over the world group and check the result.

    Every rank contributes a float32 tensor of ones, and the test asserts
    that each element of the reduced tensor equals the communicator's
    world size.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    tensor = pynccl_comm.all_reduce(tensor)
    # Wait for the collective to finish before reading the result.
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
67
68
69
70
71
72
73
74


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Run the basic all-reduce worker across two processes."""
    num_ranks = 2
    distributed_run(fn=worker_fn, world_size=num_ranks)


75
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    """All-reduce inside two disjoint 2-rank groups, independently.

    Global ranks {0, 1} and {2, 3} each form their own gloo group and their
    own ``PyNcclCommunicator``. The first group reduces a tensor of ones
    twice and expects 4; the second reduces once and expects 2.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank calls `new_group` for both groups, not only for the group
    # it belongs to (torch.distributed requires all processes to enter
    # `new_group`).
    groups = [
        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    ]
    in_first_group = rank in [0, 1]
    group = groups[0] if in_first_group else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # two groups can communicate independently
    if in_first_group:
        tensor = pynccl_comm.all_reduce(tensor)
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        # ones -> 2 -> 4 after two reductions over two ranks
        assert torch.all(tensor == 4).cpu().item()
    else:
        tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        assert torch.all(tensor == 2).cpu().item()
95
96
97


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce():
    """Run the two-group all-reduce worker across four processes."""
    # this tests pynccl for multiple tp groups, in a standalone way
    # i.e. call `pynccl_comm.all_reduce` directly
    distributed_run(multiple_allreduce_worker_fn, 4)
103
104
105


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    """All-reduce through vLLM's ``tensor_model_parallel_all_reduce``.

    Calls ``ensure_model_parallel_initialized(2, 2)`` and then, inside
    ``graph_capture``, reduces a tensor of ones: twice on global ranks 0
    and 1 (expecting 4), once on the other ranks (expecting 2).
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # NOTE(review): presumably (tensor-parallel size, pipeline-parallel
    # size), giving TP groups {0, 1} and {2, 3} -- confirm against
    # `ensure_model_parallel_initialized`.
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with graph_capture(device=device):
        # two tp groups can communicate independently
        if rank in [0, 1]:
            tensor = tensor_model_parallel_all_reduce(tensor)
            tensor = tensor_model_parallel_all_reduce(tensor)
            torch.cuda.synchronize()
            assert torch.all(tensor == 4).cpu().item()
        else:
            tensor = tensor_model_parallel_all_reduce(tensor)
            torch.cuda.synchronize()
            assert torch.all(tensor == 2).cpu().item()
121
122
123
124


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce_with_vllm():
    """Run the vLLM tensor-parallel all-reduce worker across four processes."""
    # this tests pynccl for multiple tp groups, together with vllm
    # i.e. call `tensor_model_parallel_all_reduce`
    distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)
129
130


131
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a PyNccl all-reduce in a CUDA graph, replay it, and check it.

    The all-reduce of a 4x4 tensor of ones is recorded under
    ``torch.cuda.graph``; after ``graph.replay()`` the captured output is
    asserted to equal the communicator's world size everywhere.
    """
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                         device=get_world_group().device)
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            a_out = pynccl_comm.all_reduce(a)
        torch.cuda.synchronize()
        graph.replay()
        # Make sure the replayed work has finished before reading `a_out`.
        torch.cuda.synchronize()
        assert torch.all(a_out == pynccl_comm.world_size).cpu().item()
146
147


148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@worker_fn_wrapper
def all_gather_worker_fn():
    """Check ``PyNcclCommunicator.all_gather`` against a locally built result.

    Each rank contributes ``num_elems`` float32 values equal to
    ``arange(num_elems) + rank * num_elems``. The gathered buffer is
    compared with the concatenation of every rank's contribution in rank
    order.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    # Output buffer with room for one contribution per rank.
    result = torch.zeros(num_elems * world_size,
                         dtype=torch.float32,
                         device=device)

    # Rebuild every rank's contribution locally to form the reference.
    expected = torch.cat([
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]).to(device)

    pynccl_comm.all_gather(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    """Run the all-gather worker across two processes."""
    num_ranks = 2
    distributed_run(fn=all_gather_worker_fn, world_size=num_ranks)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    """Check ``PyNcclCommunicator.reduce_scatter`` against a local reference.

    Each rank contributes ``num_elems`` float32 values equal to
    ``arange(num_elems) + rank * num_elems``. The reference for this rank
    is the element-wise sum of every rank's contribution, restricted to
    this rank's contiguous chunk of ``num_elems // world_size`` elements.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    # The input must split evenly into one chunk per rank.
    assert (num_elems % world_size == 0)
    scattered_size = num_elems // world_size
    result = torch.zeros(scattered_size, dtype=torch.float32, device=device)

    # Calculate expected result for this rank's chunk
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]
    chunk = slice(rank * scattered_size, (rank + 1) * scattered_size)
    # Distinct loop name: the original generator variable shadowed `tensor`.
    expected = sum(contribution[chunk]
                   for contribution in all_tensors).to(device)

    pynccl_comm.reduce_scatter(result, tensor)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    """Run the reduce-scatter worker across two processes."""
    num_ranks = 2
    distributed_run(fn=reduce_scatter_worker_fn, world_size=num_ranks)


217
218
219
220
221
222
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Run the CUDA-graph all-reduce worker across two processes."""
    num_ranks = 2
    distributed_run(fn=worker_fn_with_cudagraph, world_size=num_ranks)


223
224
@worker_fn_wrapper
def send_recv_worker_fn():
    """Send a tensor of ones from rank 0 and check it on the receiver.

    Rank 0 sends to ``(rank + 1) % world_size``; every other rank receives
    from ``(rank - 1) % world_size`` into an uninitialised buffer. All
    ranks then assert their tensor is all ones.

    NOTE(review): only rank 0 sends, so this appears to be written for a
    two-rank world -- confirm before running with more ranks.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size

    if rank == 0:
        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(rank)
        pynccl_comm.send(tensor, dst=(rank + 1) % world_size)
    else:
        # Uninitialised buffer; its contents are overwritten by `recv`.
        tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(rank)
        pynccl_comm.recv(tensor, src=(rank - 1) % world_size)
    torch.cuda.synchronize()
    assert torch.all(tensor == 1).cpu().item()
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
    """Run the point-to-point send/recv worker across two processes."""
    num_ranks = 2
    distributed_run(fn=send_recv_worker_fn, world_size=num_ranks)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    """Send/recv inside two disjoint 2-rank groups at the same time.

    Global ranks {0, 2} and {1, 3} each form a gloo group with its own
    ``PyNcclCommunicator``. Rank 0 sends a tensor of ones and rank 1 a
    tensor of twos; ranks 2 and 3 receive. Ranks 0 and 2 then expect ones,
    ranks 1 and 3 expect twos.
    """
    global_rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{global_rank}")
    # Every rank calls `new_group` for both groups, not only for the group
    # it belongs to (torch.distributed requires all processes to enter
    # `new_group`).
    groups = [
        torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
        torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    ]
    group = groups[0] if global_rank in [0, 2] else groups[1]
    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    if global_rank == 0:
        tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    elif global_rank == 1:
        tensor = 2 * torch.ones(
            16, 1024, 1024, dtype=torch.float32, device=device)
    else:
        # Receive buffer; its contents are overwritten by `recv`.
        tensor = torch.empty(16,
                             1024,
                             1024,
                             dtype=torch.float32,
                             device=device)
    # Peers are addressed with the communicator's own rank and world size,
    # not the global rank.
    if global_rank in [0, 1]:
        pynccl_comm.send(tensor,
                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
    else:
        pynccl_comm.recv(tensor,
                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
    torch.cuda.synchronize()
    if global_rank in [0, 2]:
        assert torch.all(tensor == 1).cpu().item()
    else:
        assert torch.all(tensor == 2).cpu().item()
281
282
283
284
285
286
287
288


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
    """Run the two-group send/recv worker across four processes."""
    num_ranks = 4
    distributed_run(fn=multiple_send_recv_worker_fn, world_size=num_ranks)


289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    """Run the broadcast worker across four processes."""
    num_ranks = 4
    distributed_run(fn=broadcast_worker_fn, world_size=num_ranks)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every root rank in turn and verify each payload.

    Each rank owns a tensor filled with its own rank number and holds
    uninitialised buffers for all other ranks; broadcasting once per root
    makes this effectively an all-gather.
    """
    cpu_group = get_world_group().cpu_group
    comm_device = get_world_group().device
    pynccl_comm = PyNcclCommunicator(cpu_group, device=comm_device)

    shape = (16, 1024, 1024)
    buffers = []
    for _ in range(pynccl_comm.world_size):
        buffers.append(
            torch.empty(*shape,
                        dtype=torch.float32,
                        device=pynccl_comm.device))
    # This rank's own slot carries the payload it will broadcast.
    own_payload = torch.ones(*shape,
                             dtype=torch.float32,
                             device=pynccl_comm.device)
    buffers[pynccl_comm.rank] = own_payload * pynccl_comm.rank

    for root, buf in enumerate(buffers):
        pynccl_comm.broadcast(buf, src=root)
        # The broadcast may have been launched on another stream, so
        # synchronize before inspecting the tensor.
        torch.cuda.synchronize()
        matches = torch.all(buf == root)
        assert matches.cpu().item()


321
def test_ncclGetUniqueId():
    """Smoke-test that ``NCCLLibrary.ncclGetUniqueId`` returns a value."""
    lib = NCCLLibrary()
    unique_id = lib.ncclGetUniqueId()
    # `list(unique_id.internal)` is something like this:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    # as long as the function doesn't raise an exception, we're good
    assert unique_id is not None