# test_pynccl.py: tests for PyNcclCommunicator and the NCCLLibrary wrapper.
import functools
import multiprocessing
import os

import pytest
import torch

from vllm.distributed.communication_op import (  # noqa
    graph_capture, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables
14
15
16
17
18
19


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` local processes and check they all succeed.

    Each process is started with ``fn(env)``, where ``env`` is a dict of
    string-valued distributed environment variables (``RANK``,
    ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``,
    ``MASTER_PORT``). ``multiprocessing.Process`` cannot set the child's
    environment directly, so ``fn`` is responsible for applying ``env``.

    Args:
        fn: Process target, called once per rank as ``fn(env)``.
        world_size: Number of processes to launch. Process ``i`` gets
            ``RANK`` and ``LOCAL_RANK`` equal to ``i``.

    Raises:
        AssertionError: If any process exits with a non-zero exit code. The
            message lists every failing rank together with its exit code.
    """
    processes = []
    for rank in range(world_size):
        env = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Collect every failure instead of asserting per process, so the report
    # shows which ranks failed and how, not just that something did.
    failures = [(rank, p.exitcode) for rank, p in enumerate(processes)
                if p.exitcode != 0]
    assert not failures, (
        "distributed worker(s) failed: " +
        ", ".join(f"rank {rank} exited with code {code}"
                  for rank, code in failures))

37

38
def worker_fn_wrapper(fn):
    """Adapt a zero-argument test body into a ``distributed_run`` target.

    ``multiprocessing.Process`` cannot accept environment variables directly,
    so the parent passes them as an argument. The returned wrapper installs
    them in the child before doing any distributed setup. In the child
    process, ``wrapped_fn(env)``:

    1. applies ``env`` with ``update_environment_variables``;
    2. selects ``cuda:<LOCAL_RANK>`` as the current CUDA device;
    3. calls ``init_distributed_environment()``;
    4. runs ``fn()``.

    Args:
        fn: Zero-argument callable holding the per-rank test logic.

    Returns:
        A callable ``wrapped_fn(env)`` suitable as a ``Process`` target.
    """

    # Copy fn's __name__/__qualname__/__module__ onto the wrapper. Without
    # this, the decorated module-level worker has the qualname
    # 'worker_fn_wrapper.<locals>.wrapped_fn' and cannot be pickled by
    # reference. Pickling is required by the 'spawn'/'forkserver' start
    # methods, and it also keeps the worker's real name in tracebacks.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        init_distributed_environment()
        fn()

    return wrapped_fn
51
52


53
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones across every rank with PyNccl.

    The assertion expects each element to equal the world size after the
    all-reduce, so the mean must be exactly ``world_size``.
    """
    pynccl_comm = PyNcclCommunicator()
    # Allocate directly on this rank's GPU, as the other workers in this file
    # do, instead of building a 64 MiB float32 tensor on the host and then
    # copying it over with `.cuda()`.
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32,
                        device=f"cuda:{pynccl_comm.rank}")
    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_reduce(tensor)
    result = tensor.mean().cpu().item()
    assert result == pynccl_comm.world_size
62
63
64
65
66
67
68
69


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Run the basic PyNccl all-reduce check on two ranks."""
    distributed_run(worker_fn, world_size=2)


70
71
72
73
74
75
76
77
@worker_fn_wrapper
def multiple_tp_worker_fn():
    """Run two independent 2-rank PyNccl communicators side by side.

    Ranks {0, 1} and {2, 3} each get their own gloo process group, which is
    wrapped in a PyNcclCommunicator. The first pair all-reduces twice
    (ones -> 2 -> 4) and the second pair only once (ones -> 2). The asserted
    values hold only if each pair sees just its own traffic.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # Every rank creates both groups, in the same order, because
    # torch.distributed.new_group must be entered by all processes. Each
    # rank then keeps only the group it belongs to.
    first_pair = torch.distributed.new_group(ranks=[0, 1], backend="gloo")
    second_pair = torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    in_first_pair = rank in [0, 1]
    group = first_pair if in_first_pair else second_pair

    pynccl_comm = PyNcclCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with pynccl_comm.change_state(enable=True):
        # The two groups communicate independently of each other.
        rounds, expected = (2, 4) if in_first_pair else (1, 2)
        for _ in range(rounds):
            pynccl_comm.all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == expected
91
92
93


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp():
    """Exercise PyNccl with two tensor-parallel groups, standalone.

    The workers call ``pynccl_comm.all_reduce`` directly rather than going
    through vLLM's collective helpers.
    """
    distributed_run(multiple_tp_worker_fn, world_size=4)


@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
    """Check independent TP-group all-reduces through vLLM's parallel state.

    ``ensure_model_parallel_initialized(2, 2)`` builds vLLM's model-parallel
    groups. The assertions expect ranks {0, 1} and {2, 3} to reduce
    independently: the first pair reduces twice (ones -> 2 -> 4) and the
    second pair once (ones -> 2). The calls go through
    ``tensor_model_parallel_all_reduce`` inside ``graph_capture()``.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    ensure_model_parallel_initialized(2, 2)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    with graph_capture():
        # The two TP groups communicate independently of each other.
        if rank in [0, 1]:
            rounds, expected = 2, 4
        else:
            rounds, expected = 1, 2
        for _ in range(rounds):
            tensor = tensor_model_parallel_all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == expected


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp_with_vllm():
    """Exercise PyNccl with multiple TP groups via vLLM's own helpers.

    The workers call ``tensor_model_parallel_all_reduce`` instead of using
    the communicator directly.
    """
    distributed_run(multiple_tp_with_vllm_worker_fn, world_size=4)
125
126


127
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture a PyNccl all-reduce in a CUDA graph and replay it once.

    Capturing must only record the collective, so the tensor stays all ones.
    A single replay must execute it, after which the mean equals the world
    size.
    """
    with torch.no_grad():
        cuda_graph = torch.cuda.CUDAGraph()
        pynccl_comm = PyNcclCommunicator()
        # Do some work on the default stream first so torch is initialized
        # before capture begins.
        buf = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
        torch.cuda.synchronize()

        with torch.cuda.graph(cuda_graph, stream=pynccl_comm.stream):
            with pynccl_comm.change_state(enable=True):
                # Work issued during stream capture is recorded, not executed.
                # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
                pynccl_comm.all_reduce(buf)

        pynccl_comm.stream.synchronize()
        # Nothing has run yet: the buffer still holds ones.
        assert buf.mean().cpu().item() == 1

        cuda_graph.replay()
        pynccl_comm.stream.synchronize()
        # One replay means one all-reduce.
        assert buf.mean().cpu().item() == pynccl_comm.world_size
146
147
148
149
150
151
152
153
154


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Run the CUDA-graph capture/replay check on two ranks."""
    distributed_run(worker_fn_with_cudagraph, world_size=2)


def test_ncclGetUniqueId():
    """``NCCLLibrary.ncclGetUniqueId`` returns an ID without raising.

    The ID's contents are opaque. ``list(unique_id.internal)`` looks
    something like ``[34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10,
    -128, 0, 29, 0, 0, ...]``: a short non-zero prefix followed by zero
    padding. The real check is that the call does not raise.
    """
    unique_id = NCCLLibrary().ncclGetUniqueId()
    assert unique_id is not None