test_custom_all_reduce.py 4.67 KB
Newer Older
1
import os
2
3
4
5
6
7
8
import random

import pytest
import ray
import torch
import torch.distributed as dist

9
from vllm.distributed.communication_op import (  # noqa
10
    tensor_model_parallel_all_reduce)
11
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
12
                                             get_tp_group, graph_capture)
13

14
15
from ..utils import (ensure_model_parallel_initialized,
                     init_test_distributed_environment,
16
                     multi_process_tensor_parallel)
17
18
19
20
21
22
23
24

# Fixed seed so every run (and every worker importing this module) sees the
# same list of tensor sizes.
random.seed(42)
# Eight pseudo-random element counts, each rounded down to a multiple of 8.
# NOTE(review): the multiple-of-8 rounding presumably satisfies an alignment
# requirement of the all-reduce implementation under test -- confirm.
test_sizes = [
    drawn - drawn % 8
    for drawn in (random.randint(1024, 2048 * 1024) for _ in range(8))
]


@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
    """Check the tensor-parallel all-reduce under CUDA graph capture.

    Runs as a Ray remote task. For every size in ``test_sizes`` and each of
    float32 / float16 / bfloat16, a CUDA graph is captured in which
    ``tensor_model_parallel_all_reduce`` and an in-place ``dist.all_reduce``
    are applied to the same inputs. The graph is then replayed and the two
    results are asserted to agree.

    Args:
        tp_size: Forwarded to the distributed-init helpers; also used to
            derive how many communication rounds this rank captures.
        pp_size: Forwarded to the distributed-init helpers.
        rank: Selects the CUDA device (``cuda:{rank}``) and is forwarded to
            ``init_test_distributed_environment``.
        distributed_init_port: Forwarded to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a replayed result differs from the reference.
    """
    # Remove the variable so that ``cuda:{rank}`` can be selected below
    # (presumably Ray sets it to restrict the worker to one GPU -- confirm).
    # ``pop`` with a default does not raise when the variable is absent.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    ensure_model_parallel_initialized(tp_size, pp_size)
    group = get_tensor_model_parallel_group().device_group

    # A small all_reduce for warmup.
    # this is needed because device communicators might be created lazily
    # (e.g. NCCL). This will ensure that the communicator is initialized
    # before any communication happens, so that this group can be used for
    # graph capture immediately.
    data = torch.zeros(1, device=device)
    dist.all_reduce(data, group=group)
    torch.cuda.synchronize()
    del data

    # we use the first group to communicate once
    # and the second group to communicate twice
    # and so on
    # this is used to demonstrate that each group can
    # communicate independently
    num_communication = rank // tp_size + 1

    for sz in test_sizes:
        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            with graph_capture() as graph_capture_context:
                # use integers so result matches NCCL exactly
                inp1 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                inp2 = torch.randint(1,
                                     16, (sz, ),
                                     dtype=dtype,
                                     device=torch.cuda.current_device())
                torch.cuda.synchronize()
                graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph,
                                      stream=graph_capture_context.stream):
                    for _ in range(num_communication):
                        out1 = tensor_model_parallel_all_reduce(inp1)
                        # the input buffer is immediately modified to test
                        # synchronization
                        dist.all_reduce(inp1, group=group)
                        out2 = tensor_model_parallel_all_reduce(inp2)
                        dist.all_reduce(inp2, group=group)
            graph.replay()
            # After replay the in-place reduced inputs serve as the reference
            # for the outputs of the all-reduce under test.
            assert torch.allclose(out1, inp1)
            assert torch.allclose(out2, inp2)


@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
    """Check ``all_reduce_unreg`` of the TP group's ``ca_comm`` eagerly.

    Runs as a Ray remote task. A tensor of ones is passed through
    ``all_reduce_unreg`` ``num_communication`` times, and the result is
    asserted to equal the input scaled by ``tp_size ** num_communication``.
    The check is done once for float32 and once for bfloat16.

    Args:
        tp_size: Forwarded to ``init_test_distributed_environment``; also
            used for the round count and the expected scale factor.
        pp_size: Forwarded to ``init_test_distributed_environment``.
        rank: Selects the CUDA device (``cuda:{rank}``) and is forwarded to
            ``init_test_distributed_environment``.
        distributed_init_port: Forwarded to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a reduced tensor differs from the expected value.
    """
    # Remove the variable so that ``cuda:{rank}`` can be selected below
    # (presumably Ray sets it to restrict the worker to one GPU -- confirm).
    # ``pop`` with a default does not raise when the variable is absent.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    # we use the first group to communicate once
    # and the second group to communicate twice
    # and so on
    # this is used to demonstrate that each group can
    # communicate independently
    num_communication = rank // tp_size + 1
    sz = 1024
    fa = get_tp_group().ca_comm
    # Each round is expected to multiply a tensor of ones by tp_size.
    expected_scale = tp_size**num_communication

    # Same check for both dtypes; the bfloat16 case uses four times as many
    # elements as the float32 case.
    for numel, dtype in ((sz, torch.float32), (sz * 4, torch.bfloat16)):
        inp = torch.ones(numel, dtype=dtype, device=device)
        out = inp
        for _ in range(num_communication):
            out = fa.all_reduce_unreg(out)
        assert torch.allclose(out, inp * expected_scale)
107
108


109
110
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
    """Run one all-reduce check across ``tp_size * pipeline_parallel_size`` GPUs.

    The test is skipped when fewer CUDA devices are visible than the
    requested world size; otherwise ``test_target`` is handed to
    ``multi_process_tensor_parallel`` together with both parallel sizes.
    """
    required_gpus = tp_size * pipeline_parallel_size
    available_gpus = torch.cuda.device_count()
    if available_gpus < required_gpus:
        pytest.skip("Not enough GPUs to run the test.")
    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)