# test_pynccl.py: tests for vLLM's PyNcclCommunicator.
import functools
import multiprocessing
import os
import socket
from typing import Dict, List

import pytest
import torch
import torch.distributed

from vllm.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             get_world_group, graph_capture,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables


def _get_open_port() -> int:
    """Return a TCP port that is currently free on this machine.

    The OS chooses the port when binding to port 0. The probe socket is
    closed before the workers start, so another process could still take the
    port in between. That window is much smaller than the collision risk of a
    single hard-coded port shared by every run on the host.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` processes and require all to succeed.

    Process ``i`` is started as ``fn(env)``, where ``env`` holds the
    rendezvous variables for rank ``i``: ``RANK``, ``LOCAL_RANK``,
    ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``.
    ``multiprocessing.Process`` cannot set a child's environment directly,
    so ``fn`` is responsible for applying ``env``.

    Args:
        fn: Process target taking the environment dict as its only argument.
            Under the ``spawn``/``forkserver`` start methods it must be
            picklable.
        world_size: Number of worker processes (one per rank).

    Raises:
        AssertionError: If any worker exits with a non-zero exit code. The
            message maps each failing rank to its exit code.
    """
    # Choose the rendezvous port once per run, so concurrent runs on the same
    # host do not compete for one fixed port.
    master_port = str(_get_open_port())

    processes: List[multiprocessing.Process] = []
    for rank in range(world_size):
        env: Dict[str, str] = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': master_port,
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Report every failing rank at once instead of stopping at the first one.
    failed = {
        rank: p.exitcode
        for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failed, ("worker processes exited with a non-zero code "
                        f"(rank -> exit code): {failed}")


def worker_fn_wrapper(fn):
    """Turn a no-argument test body into a ``distributed_run`` process target.

    ``multiprocessing.Process`` cannot accept environment variables directly,
    so ``distributed_run`` passes them as a dict argument. The returned
    wrapper does the following, in order:

    * installs those variables;
    * selects the CUDA device named by ``LOCAL_RANK``;
    * calls ``init_distributed_environment()``;
    * runs ``fn()``.

    ``functools.wraps`` copies ``fn``'s ``__name__``, ``__qualname__`` and
    ``__module__`` onto the wrapper. When the wrapper replaces ``fn`` as a
    module-level name (via ``@worker_fn_wrapper``), this has two effects:

    * it can be pickled by reference, which the ``spawn``/``forkserver``
      start methods require;
    * it appears under the original name in tracebacks.
    """

    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    # Every rank contributes a tensor of ones. The check below expects each
    # element (and therefore the mean) to equal the world size after the
    # all-reduce.
    comm = PyNcclCommunicator(get_world_group().cpu_group,
                              device=get_world_group().device)
    ones = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    with comm.change_state(enable=True):
        reduced = comm.all_reduce(ones)
    torch.cuda.synchronize()
    mean_value = reduced.mean().cpu().item()
    assert mean_value == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """All-reduce a tensor of ones across 2 ranks via PyNcclCommunicator."""
    distributed_run(fn=worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_allreduce_worker_fn():
    # Split 4 ranks into two 2-rank groups, {0, 1} and {2, 3}, each driving
    # its own PyNcclCommunicator.
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # torch.distributed.new_group must be entered by every rank for every
    # group, including groups the rank is not a member of.
    first_pair = torch.distributed.new_group(ranks=[0, 1], backend="gloo")
    second_pair = torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    in_first_pair = rank in [0, 1]
    pynccl_comm = PyNcclCommunicator(
        group=first_pair if in_first_pair else second_pair, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with pynccl_comm.change_state(enable=True):
        # The two groups communicate independently: {0, 1} reduces twice
        # (expected mean 4), {2, 3} reduces once (expected mean 2).
        num_reductions = 2 if in_first_pair else 1
        for _ in range(num_reductions):
            tensor = pynccl_comm.all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == (4 if in_first_pair else 2)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce():
    """Test pynccl with multiple tp groups in a standalone way.

    The worker calls ``pynccl_comm.all_reduce`` directly.
    """
    distributed_run(fn=multiple_allreduce_worker_fn, world_size=4)


@worker_fn_wrapper
def multiple_allreduce_with_vllm_worker_fn():
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # NOTE(review): presumably (tp=2, pp=2), putting ranks {0, 1} and {2, 3}
    # in separate tp groups, as the expected results below assume; confirm.
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    in_first_group = rank in [0, 1]
    with graph_capture():
        # Two tp groups can communicate independently: the first reduces
        # twice (expected mean 4), the second once (expected mean 2).
        for _ in range(2 if in_first_group else 1):
            tensor = tensor_model_parallel_all_reduce(tensor)
        torch.cuda.synchronize()
        result = tensor.mean().cpu().item()
        assert result == (4 if in_first_group else 2)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_allreduce_with_vllm():
    """Test pynccl with multiple tp groups together with vllm.

    The worker goes through ``tensor_model_parallel_all_reduce``.
    """
    distributed_run(fn=multiple_allreduce_with_vllm_worker_fn, world_size=4)


@worker_fn_wrapper
def worker_fn_with_cudagraph():
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        comm = PyNcclCommunicator(get_world_group().cpu_group,
                                  device=get_world_group().device)
        # run something in the default stream to initialize torch engine
        ones = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        # Record the all-reduce into `graph` on the communicator's stream.
        # The result is only checked after the graph has been replayed.
        capture = torch.cuda.graph(graph, stream=comm.stream)
        with capture, comm.change_state(enable=True):
            reduced = comm.all_reduce(ones)
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()
        assert reduced.mean().cpu().item() == comm.world_size


@worker_fn_wrapper
def all_gather_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    rank, world_size = pynccl_comm.rank, pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'
    num_elems = 1000

    def rank_payload(owner, dev=None):
        # Rank `owner` contributes the values
        # owner * num_elems .. (owner + 1) * num_elems - 1.
        return torch.arange(num_elems, dtype=torch.float32,
                            device=dev) + owner * num_elems

    local_chunk = rank_payload(rank, device)
    gathered = torch.zeros(num_elems * world_size,
                           dtype=torch.float32,
                           device=device)
    # All ranks' payloads concatenated in rank order.
    expected = torch.cat([rank_payload(r)
                          for r in range(world_size)]).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_gather(gathered, local_chunk)
    torch.cuda.synchronize()
    torch.testing.assert_close(gathered, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    """All-gather distinct per-rank ranges across 2 ranks."""
    distributed_run(fn=all_gather_worker_fn, world_size=2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    rank, world_size = pynccl_comm.rank, pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'
    num_elems = 1000

    def rank_payload(owner, dev=None):
        # Rank `owner` feeds arange(num_elems) shifted by owner * num_elems.
        return torch.arange(num_elems, dtype=torch.float32,
                            device=dev) + owner * num_elems

    full_input = rank_payload(rank, device)
    assert num_elems % world_size == 0
    chunk = num_elems // world_size
    output = torch.zeros(chunk, dtype=torch.float32, device=device)

    # This rank's expected chunk: the element-wise sum over all ranks'
    # inputs, restricted to slice [rank * chunk, (rank + 1) * chunk).
    lo, hi = rank * chunk, (rank + 1) * chunk
    expected = sum(rank_payload(r)[lo:hi] for r in range(world_size))
    expected = expected.to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.reduce_scatter(output, full_input)
    torch.cuda.synchronize()
    torch.testing.assert_close(output, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    """Reduce-scatter per-rank ranges across 2 ranks."""
    distributed_run(fn=reduce_scatter_worker_fn, world_size=2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Capture a pynccl all-reduce in a CUDA graph and replay it on 2 ranks."""
    distributed_run(fn=worker_fn_with_cudagraph, world_size=2)


@worker_fn_wrapper
def send_recv_worker_fn():
    comm = PyNcclCommunicator(get_world_group().cpu_group,
                              device=get_world_group().device)
    is_sender = comm.rank == 0
    # Rank 0 sends a tensor of ones to its successor. Every other rank
    # receives from its predecessor into an uninitialized buffer.
    make_buffer = torch.ones if is_sender else torch.empty
    tensor = make_buffer(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    with comm.change_state(enable=True):
        if is_sender:
            comm.send(tensor, dst=(comm.rank + 1) % comm.world_size)
        else:
            comm.recv(tensor, src=(comm.rank - 1) % comm.world_size)
    torch.cuda.synchronize()
    assert tensor.mean().cpu().item() == 1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
    """Point-to-point send/recv of a tensor of ones between 2 ranks."""
    distributed_run(fn=send_recv_worker_fn, world_size=2)


@worker_fn_wrapper
def multiple_send_recv_worker_fn():
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Two independent 2-rank groups: even ranks {0, 2} and odd ranks {1, 3}.
    # Every rank must enter new_group for both groups.
    even_group = torch.distributed.new_group(ranks=[0, 2], backend="gloo")
    odd_group = torch.distributed.new_group(ranks=[1, 3], backend="gloo")
    pynccl_comm = PyNcclCommunicator(
        group=even_group if rank in [0, 2] else odd_group, device=device)

    shape = (16, 1024, 1024)
    # Senders fill their buffer with a group-specific value (1 for the even
    # group, 2 for the odd group); receivers start uninitialized.
    if rank == 0:
        tensor = torch.ones(*shape, dtype=torch.float32, device=device)
    elif rank == 1:
        tensor = 2 * torch.ones(*shape, dtype=torch.float32, device=device)
    else:
        tensor = torch.empty(*shape, dtype=torch.float32, device=device)

    group_rank = pynccl_comm.rank
    group_size = pynccl_comm.world_size
    with pynccl_comm.change_state(enable=True):
        if rank in [0, 1]:
            pynccl_comm.send(tensor, dst=(group_rank + 1) % group_size)
        else:
            pynccl_comm.recv(tensor, src=(group_rank - 1) % group_size)
    torch.cuda.synchronize()

    expected = 1 if rank in [0, 2] else 2
    assert tensor.mean().cpu().item() == expected


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
    """Concurrent send/recv in two independent pynccl groups on 4 ranks."""
    distributed_run(fn=multiple_send_recv_worker_fn, world_size=4)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_broadcast():
    """Broadcast from each of 4 root ranks in turn."""
    distributed_run(fn=broadcast_worker_fn, world_size=4)


@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every root rank in turn and check each result.

    Rank ``r`` owns ``recv_tensors[r]``, filled with the value ``r``; the
    other slots start uninitialized. After slot ``i`` is broadcast from root
    ``i``, every rank must hold ``i`` in that slot, so the test as a whole is
    essentially an all-gather.
    """
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pynccl_comm.device)
        for _ in range(pynccl_comm.world_size)
    ]
    recv_tensors[pynccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pynccl_comm.device) * pynccl_comm.rank

    # Every other pynccl op in this file runs under
    # change_state(enable=True). Do the same here for consistency.
    # NOTE(review): presumably change_state controls whether the
    # communicator actually issues NCCL calls; confirm against
    # PyNcclCommunicator.
    with pynccl_comm.change_state(enable=True):
        for i in range(pynccl_comm.world_size):
            pynccl_comm.broadcast(recv_tensors[i], src=i)
            # the broadcast op might be launched in a different stream
            # need to synchronize to make sure the tensor is ready
            torch.cuda.synchronize()
            assert torch.all(recv_tensors[i] == i).cpu().item()


def test_ncclGetUniqueId():
    # Smoke test: the call only has to succeed. The id is an opaque byte
    # array, mostly zero-filled after a short non-zero prefix, e.g.
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29,
    #  0, 0, 0, ...]. Its exact contents are not checked.
    unique_id = NCCLLibrary().ncclGetUniqueId()
    assert unique_id is not None