# test_pynccl.py
import multiprocessing

import pytest
import torch

6
7
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                          ncclGetUniqueId)
8
from vllm.distributed.parallel_state import init_distributed_environment
9
from vllm.utils import update_environment_variables
10
11
12
13
14
15


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require success.

    Every child is a ``multiprocessing.Process`` whose target is called as
    ``fn(env)``. ``env`` is a dict of string values with the keys ``RANK``,
    ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR`` and
    ``MASTER_PORT``; the two rank entries are the child's index and the two
    size entries are ``world_size``.

    Args:
        fn: Callable executed in each child with the ``env`` dict as its only
            argument.
        world_size: Number of child processes to start.

    Raises:
        AssertionError: If any child finishes with a non-zero exit code.
    """
    processes = []
    for rank in range(world_size):
        env = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        process = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(process)
        process.start()

    # Wait for all children first so that every exit code is available
    # before any of them is checked.
    for process in processes:
        process.join()

    # Children are stored in start order, so the list index is the rank.
    failed = {
        rank: process.exitcode
        for rank, process in enumerate(processes) if process.exitcode != 0
    }
    assert not failed, (
        f"worker process(es) failed (rank -> exit code): {failed}")

33

34
def worker_fn_wrapper(fn):
    """Decorate a zero-argument worker so it can be a process target.

    `multiprocessing.Process` cannot accept environment variables directly,
    so they are passed to the worker as an argument instead. The returned
    function takes that ``env`` mapping, applies it with
    ``update_environment_variables``, calls ``init_distributed_environment``
    and only then runs ``fn``.

    Args:
        fn: Worker body, called with no arguments.

    Returns:
        A function taking a single ``env`` argument.
    """
    import functools

    # Keep the wrapped worker's name and docstring on the wrapper.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        init_distributed_environment()
        fn()

    return wrapped_fn
44
45


46
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones and check its mean equals the world size.

    The tensor is placed on the CUDA device whose index is ``comm.rank``.
    """
    comm = NCCLCommunicator()
    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
    comm.all_reduce(tensor)
    # Every element started as 1, so the test expects the mean after the
    # all-reduce to equal the number of participating ranks.
    result = tensor.mean().cpu().item()
    assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
    """Run ``worker_fn`` in two processes via ``distributed_run``."""
    distributed_run(fn=worker_fn, world_size=2)


61
@worker_fn_wrapper
def worker_fn_with_cudagraph():
    """Capture an all-reduce in a CUDA graph and check it runs on replay.

    The all-reduce is issued while capturing on ``comm.stream``. The test
    asserts that the tensor of ones is unchanged right after capture (mean 1)
    and that its mean equals ``comm.world_size`` after one ``graph.replay()``.
    """
    with torch.no_grad():
        graph = torch.cuda.CUDAGraph()
        comm = NCCLCommunicator()
        # run something in the default stream to initialize torch engine
        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, stream=comm.stream):
            # operation during the graph capture is recorded but not executed
            # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
            comm.all_reduce(a)
        comm.stream.synchronize()
        # Capture only: `a` must still hold ones (world_size**0 == 1).
        assert a.mean().cpu().item() == comm.world_size**0
        graph.replay()
        comm.stream.synchronize()
        # One replay: the captured all-reduce has now been applied once.
        assert a.mean().cpu().item() == comm.world_size**1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
    """Run ``worker_fn_with_cudagraph`` in two processes."""
    distributed_run(fn=worker_fn_with_cudagraph, world_size=2)


def test_ncclGetUniqueId():
    """Smoke test: ``ncclGetUniqueId`` returns a value without raising."""
    # A sample `list(unique_id.internal)` recorded with this test begins
    # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, ...]
    # and is followed by a long run of zeros. The contents are not checked:
    # as long as the call doesn't raise an exception, we're good.
    assert ncclGetUniqueId() is not None