# test_pynccl.py
import multiprocessing

import pytest
import torch

6
7
import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
8
9
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                          ncclGetUniqueId)
10
11
12
from vllm.distributed.parallel_state import (
    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
    init_distributed_environment, with_pynccl_for_all_reduce)
13
from vllm.utils import update_environment_variables
14
15
16
17
18
19


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` subprocesses, one per rank.

    Every subprocess is started with a single positional argument: a dict of
    environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``) describing its
    place in the job. The call blocks until all subprocesses have exited.

    Args:
        fn: callable invoked as ``fn(env)`` inside each subprocess.
        world_size: number of subprocesses (ranks) to launch.

    Raises:
        AssertionError: if any subprocess finishes with a non-zero exit code.
    """
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {
            'RANK': str(i),
            'LOCAL_RANK': str(i),
            'WORLD_SIZE': str(number_of_processes),
            'LOCAL_WORLD_SIZE': str(number_of_processes),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    # wait for every rank before inspecting any exit code
    for p in processes:
        p.join()

    # report all failing ranks at once instead of stopping at the first one
    failures = {
        rank: p.exitcode
        for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failures, (
        f"worker processes failed (rank -> exit code): {failures}")

37

38
def worker_fn_wrapper(fn):
    """Turn a zero-argument worker into one that accepts an ``env`` dict.

    `multiprocessing.Process` cannot accept environment variables directly,
    so the environment variables are passed as an argument and applied
    inside the child before the distributed environment is initialized and
    ``fn`` is called.

    Args:
        fn: zero-argument callable holding the actual test body.

    Returns:
        A callable taking the ``env`` dict as its only argument.
    """
    # local import keeps this helper self-contained
    import functools

    # `functools.wraps` copies the worker's name/qualname/docstring onto the
    # wrapper, so failures name the real worker instead of
    # `worker_fn_wrapper.<locals>.wrapped_fn`.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        init_distributed_environment()
        fn()

    return wrapped_fn
48
49


50
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones across all ranks and check the result.

    The tensor is created on the GPU indexed by the communicator's rank.
    After the all-reduce, its mean is expected to equal the world size.
    """
    comm = NCCLCommunicator()
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    comm.all_reduce(tensor)
    result = tensor.mean().cpu().item()
    assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Launch the basic all-reduce worker on two ranks."""
    distributed_run(worker_fn, world_size=2)


65
66
67
68
69
70
71
72
73
@worker_fn_wrapper
def multiple_tp_worker_fn():
    """Check that two separate 2-rank communicators work independently.

    Ranks 0-1 and ranks 2-3 each form their own group. The first group
    all-reduces twice and expects a mean of 4; the second all-reduces once
    and expects a mean of 2.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    # every rank creates both groups, then picks the one it belongs to
    groups = [
        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
    ]
    group = groups[0] if rank in [0, 1] else groups[1]
    comm = NCCLCommunicator(group=group, device=device)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # two groups can communicate independently
    if rank in [0, 1]:
        comm.all_reduce(tensor)
        comm.all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == 4
    else:
        comm.all_reduce(tensor)
        result = tensor.mean().cpu().item()
        assert result == 2


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp():
    """Run the standalone multi-group all-reduce worker on four ranks."""
    # this tests pynccl for multiple tp groups, in a standalone way
    # i.e. call `comm.all_reduce` directly
    distributed_run(multiple_tp_worker_fn, 4)


@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
    """Exercise pynccl through vLLM's tensor-parallel all-reduce helper.

    Ranks 0 and 1 all-reduce twice and expect a mean of 4; the remaining
    ranks all-reduce once and expect a mean of 2.
    """
    rank = torch.distributed.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    ensure_model_parallel_initialized(2, 2)
    cpu_group = get_tensor_model_parallel_cpu_group()
    pynccl_utils.init_process_group(group=cpu_group)
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
    # two tp groups can communicate independently: the first one reduces
    # twice, the other one only once
    rounds = 2 if rank in (0, 1) else 1
    with with_pynccl_for_all_reduce():
        for _ in range(rounds):
            tensor = tensor_model_parallel_all_reduce(tensor)
        result = tensor.mean().cpu().item()
        # the asserted value doubles with each reduction: 2 after one, 4 after two
        assert result == 2**rounds


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp_with_vllm():
    """Run the vLLM-integrated multi-group all-reduce worker on four ranks.

    Covers pynccl with multiple tp groups when driven through
    `tensor_model_parallel_all_reduce` rather than the communicator directly.
    """
    distributed_run(multiple_tp_with_vllm_worker_fn, world_size=4)
122
123


124
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Check that an all-reduce captured in a CUDA graph runs on replay only.

    The all-reduce is issued while capturing on the communicator's stream.
    The tensor's mean must still be 1 right after capture and must equal the
    world size after the graph has been replayed once.
    """
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        comm = NCCLCommunicator()
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, stream=comm.stream):
            # operation during the graph capture is recorded but not executed
            # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
            comm.all_reduce(a)
        comm.stream.synchronize()
        # nothing has executed yet: world_size**0 == 1
        assert a.mean().cpu().item() == comm.world_size**0
        graph.replay()
        comm.stream.synchronize()
        # one replay means one all-reduce: world_size**1
        assert a.mean().cpu().item() == comm.world_size**1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Launch the CUDA-graph all-reduce worker on two ranks."""
    distributed_run(worker_fn_with_cudagraph, world_size=2)


def test_ncclGetUniqueId():
    """Smoke test: requesting an NCCL unique id must not raise."""
    # The payload is not checked. For reference, `list(nccl_id.internal)`
    # is something like a few non-zero leading bytes followed by zeros:
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29,
    #  0, 0, ..., 0]
    # It is enough that the call succeeds and hands back an object.
    nccl_id = ncclGetUniqueId()
    assert nccl_id is not None