# File: test_custom_all_reduce.py (5.07 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
import random

import pytest
import ray
import torch
import torch.distributed as dist

11
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce  # noqa
12
from vllm.distributed.parallel_state import get_tp_group, graph_capture
13

14
15
16
17
18
from ..utils import (
    ensure_model_parallel_initialized,
    init_test_distributed_environment,
    multi_process_parallel,
)
19
20
21
22
23
24
25
26

# Deterministic set of eight tensor lengths consumed by the CUDA-graph test:
# each is drawn from [1024, 2048 * 1024] and then rounded down to a multiple
# of 8 (1024 is itself a multiple of 8, so the lower bound still holds).
# NOTE(review): the multiple-of-8 rounding presumably satisfies an alignment
# requirement of the custom all-reduce path -- confirm.
random.seed(42)
test_sizes = [(random.randint(1024, 2048 * 1024) // 8) * 8 for _ in range(8)]


@ray.remote(num_gpus=1, max_calls=1)
def graph_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
    """Ray worker: check the TP all-reduce when captured in a CUDA graph.

    For every length in ``test_sizes`` and every dtype in
    float32/float16/bfloat16, two random integer tensors are reduced
    out-of-place with ``tensor_model_parallel_all_reduce`` inside a CUDA
    graph. Each input is then reduced in place with ``dist.all_reduce`` on the
    same TP device group. After ``graph.replay()`` each out-of-place result
    must equal the corresponding in-place result.

    Args:
        monkeypatch: scopes the removal of ``CUDA_VISIBLE_DEVICES`` /
            ``HIP_VISIBLE_DEVICES`` to the body of this worker.
        tp_size: tensor-parallel world size.
        pp_size: pipeline-parallel world size.
        rank: this worker's global rank; also used as the CUDA device index.
        distributed_init_port: port handed to
            ``init_test_distributed_environment``.
    """
    with monkeypatch.context() as m:
        # Drop device masking so that ``cuda:{rank}`` addresses the intended
        # GPU. NOTE(review): presumably Ray restricts visibility per worker
        # via these variables -- confirm.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        m.delenv("HIP_VISIBLE_DEVICES", raising=False)

        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
        ensure_model_parallel_initialized(tp_size, pp_size)
        group = get_tp_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1, device=device)
        dist.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        # we use the first group to communicate once
        # and the second group to communicate twice
        # and so on
        # this is used to demonstrate that each group can
        # communicate independently
        num_communication = rank // tp_size + 1

        for sz in test_sizes:
            for dtype in (torch.float32, torch.float16, torch.bfloat16):
                with graph_capture(device=device) as graph_capture_context:
                    # use integers so result matches NCCL exactly
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    inp2 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    torch.cuda.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
                        for _ in range(num_communication):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization: out1 must reflect inp1 *before*
                            # this in-place reduction, which yields the same
                            # value, so out1 == inp1 after every iteration.
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                graph.replay()
                torch.testing.assert_close(out1, inp1)
                torch.testing.assert_close(out2, inp2)
84
85
86


@ray.remote(num_gpus=1, max_calls=1)
def eager_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
    """Ray worker: check the custom all-reduce communicator in eager mode.

    Tensors of ones are reduced ``num_communication`` times through
    ``get_tp_group().device_communicator.ca_comm.all_reduce`` with
    ``registered=False``. Each reduction sums across ``tp_size`` ranks, so the
    result must equal ``tp_size ** num_communication``. Two cases are checked:
    1024 float32 elements and 4096 bfloat16 elements.

    Args:
        monkeypatch: scopes the removal of ``CUDA_VISIBLE_DEVICES`` /
            ``HIP_VISIBLE_DEVICES`` to the body of this worker.
        tp_size: tensor-parallel world size.
        pp_size: pipeline-parallel world size.
        rank: this worker's global rank; also used as the CUDA device index.
        distributed_init_port: port handed to
            ``init_test_distributed_environment``.
    """
    with monkeypatch.context() as m:
        # Drop device masking so that ``cuda:{rank}`` addresses the intended
        # GPU. NOTE(review): presumably Ray restricts visibility per worker
        # via these variables -- confirm.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        m.delenv("HIP_VISIBLE_DEVICES", raising=False)

        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

        # we use the first group to communicate once
        # and the second group to communicate twice
        # and so on
        # this is used to demonstrate that each group can
        # communicate independently
        num_communication = rank // tp_size + 1
        sz = 1024
        # NOTE(review): if custom all-reduce is unavailable on this platform,
        # ``ca_comm`` may not be usable here -- confirm expected skip behavior.
        fa = get_tp_group().device_communicator.ca_comm

        # All-ones inputs are summed over tp_size ranks per reduction.
        expected_scale = tp_size**num_communication
        for numel, dtype in ((sz, torch.float32), (sz * 4, torch.bfloat16)):
            inp = torch.ones(numel, dtype=dtype, device=device)
            out = inp
            for _ in range(num_communication):
                out = fa.all_reduce(out, registered=False)
            torch.testing.assert_close(out, inp * expected_scale)
120
121


122
123
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
def test_custom_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pipeline_parallel_size,
    test_target,
):
    """Launch ``test_target`` across ``tp_size * pipeline_parallel_size`` GPUs.

    The test is skipped when ``torch.cuda.device_count()`` is smaller than the
    required world size. Otherwise the worker function is dispatched through
    ``multi_process_parallel``.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")
    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)