# test_quick_all_reduce.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import multiprocessing
5
6
7
8
9
10
11
import random

import pytest
import ray
import torch
import torch.distributed as dist

12
from vllm import _custom_ops as ops
13
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce  # noqa
14
from vllm.distributed.parallel_state import get_tp_group, graph_capture
15
16
from vllm.platforms import current_platform

17
18
19
20
21
from ..utils import (
    ensure_model_parallel_initialized,
    init_test_distributed_environment,
    multi_process_parallel,
)
22
23
24
25

torch.manual_seed(42)
random.seed(44)
# Size over 8MB is sufficient for custom quick allreduce.
# These are element counts (8-10 Mi elements each), not byte sizes.
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
# Round every size down to a multiple of 8 elements; presumably the
# quick-reduce kernels need 8-element alignment -- TODO confirm.
test_sizes = [size - size % 8 for size in test_sizes]


@ray.remote(num_gpus=1, max_calls=1)
def graph_quickreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
    """Check the TP all-reduce inside CUDA graph capture on one worker rank.

    Runs as a single-GPU Ray task. For each size in ``test_sizes`` and for
    fp16/bf16, it captures a CUDA graph. The graph calls
    ``tensor_model_parallel_all_reduce`` (whose result goes to ``out1``/``out2``)
    and then ``dist.all_reduce`` in place on the same inputs. The test replays
    the graph and checks that the two results agree within atol=2.5 / rtol=0.1.
    The loose tolerance presumably absorbs quick-reduce quantization error --
    confirm.

    Args:
        monkeypatch: pytest fixture, used to clear ``CUDA_VISIBLE_DEVICES``.
        tp_size: tensor-parallel world size.
        pp_size: pipeline-parallel world size.
        rank: global rank of this worker; also selects ``cuda:{rank}``.
        distributed_init_port: port passed to the distributed initialization.
    """
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device = torch.device(f"cuda:{rank}")
        torch.accelerator.set_device_index(device)
        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
        ensure_model_parallel_initialized(tp_size, pp_size)
        group = get_tp_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.accelerator.synchronize()
        del data

        # we use the first group to communicate once
        # and the second group to communicate twice
        # and so on
        # this is used to demonstrate that each group can
        # communicate independently
        num_communication = rank // tp_size + 1

        for sz in test_sizes:
            for dtype in [torch.float16, torch.bfloat16]:
                with graph_capture(device=device) as graph_capture_context:
                    device_idx = torch.accelerator.current_device_index()
                    inp1 = torch.randint(1, 23, (sz,), dtype=dtype, device=device_idx)
                    inp2 = torch.randint(-23, 1, (sz,), dtype=dtype, device=device_idx)

                    torch.accelerator.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
                        # Each round: the result of tensor_model_parallel_all_reduce
                        # (out1/out2) is taken first, then inp1/inp2 are reduced in
                        # place. After the last round, out1/out2 should therefore
                        # match inp1/inp2.
                        for _ in range(num_communication):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                # Captured work runs on replay, after leaving graph_capture.
                graph.replay()
                torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1)
                torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1)


@ray.remote(num_gpus=1, max_calls=1)
def eager_quickreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
    """Check eager (non-graph) quick allreduce on one worker rank.

    Runs as a single-GPU Ray task. It calls ``quick_all_reduce`` directly on the
    TP group's ``qr_comm`` communicator for an fp16 and a bf16 input. It checks
    each result equals ``inp * tp_size`` within atol=2.5 / rtol=0.1.

    Args:
        monkeypatch: pytest fixture, used to clear ``CUDA_VISIBLE_DEVICES``.
        tp_size: tensor-parallel world size.
        pp_size: pipeline-parallel world size.
        rank: global rank of this worker; also selects ``cuda:{rank}``.
        distributed_init_port: port passed to the distributed initialization.
    """
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device = torch.device(f"cuda:{rank}")
        torch.accelerator.set_device_index(device)

        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

        # Size over 8MB is sufficient for custom quick allreduce.
        sz = 16 * 1024 * 1024
        fa = get_tp_group().device_communicator.qr_comm

        # Repeating pattern 0, 1, ..., 22, built directly on the device instead
        # of from a 16M-element Python list (slow and memory hungry). The
        # integers 0..22 are exactly representable in fp16 and bf16, so the
        # input values are unchanged.
        pattern = torch.arange(sz, dtype=torch.int32, device=device) % 23
        for dtype in (torch.float16, torch.bfloat16):
            inp = pattern.to(dtype)
            out = fa.quick_all_reduce(inp)
            torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
def test_custom_quick_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pipeline_parallel_size,
    test_target,
    quant_mode,
):
    """Run ``test_target`` on ``tp_size * pipeline_parallel_size`` workers.

    Skips when fewer accelerators than the world size are available. The
    quick-reduce quantization mode is passed through the
    ``VLLM_ROCM_QUICK_REDUCE_QUANTIZATION`` environment variable. That
    variable is presumably read by the quick-reduce communicator created
    inside each worker -- confirm.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.accelerator.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)

    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)


def qr_variable_input(rank, world_size):
    """
    When the tensor parallelism is set to 4 or 8, frequent changes
    in the input shape can cause QuickReduce to hang (this issue
    has been observed with the gpt_oss model).

    Worker body for ``test_custom_quick_allreduce_variable_input``. It creates
    a raw quick-reduce context for ``rank`` through ``ops.init_custom_qr``. It
    then exchanges handles with the other ranks over a torch.distributed group.
    Finally it runs 49,999 INT4 all-reduces whose input alternates between
    (1024, 1024) zeros and (1024, 2048) ones, checking every result.

    Args:
        rank: rank of this process; also selects ``cuda:{rank}``.
        world_size: number of participating processes.

    Raises:
        AssertionError: if any all-reduce result is incorrect.
    """
    device = torch.device(f"cuda:{rank}")
    torch.accelerator.set_device_index(device)

    qr_max_size = None  # MB
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    # Used only to exchange handles. It is an NCCL group, so all_gather_object
    # stages through the current CUDA device, which was set above.
    group = dist.new_group(list(range(world_size)), backend="nccl")

    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)
    ops.qr_open_handles(_ptr, handles)

    # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
    int4_mode = 3
    dtype = torch.float16
    device_idx = torch.accelerator.current_device_index()
    rows = 1024
    for num in range(1, 50000):  # 50000 is sufficient to identify issues.
        # Change the input shape on every iteration; per the docstring,
        # frequent shape changes are what used to trigger the hang.
        if num % 2 == 0:
            inp1 = torch.zeros((rows, 1024), dtype=dtype, device=device_idx)
            expected = 0
        else:
            inp1 = torch.ones((rows, 2048), dtype=dtype, device=device_idx)
            expected = world_size
        result = torch.empty_like(inp1)
        ops.qr_all_reduce(_ptr, inp1, result, int4_mode, cast_bf2half=True)
        try:
            assert torch.all(result == expected)
        except AssertionError:
            print("Assertion failed! Allreduce results are incorrect.")
            raise

    # Tear down the NCCL process group explicitly instead of relying on
    # interpreter shutdown.
    dist.destroy_process_group()


@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
    """Run ``qr_variable_input`` in ``tp_size`` spawned processes.

    The test fails in two cases:

    - a worker is still alive after ``timeout`` seconds, which is reported as
      a QuickReduce hang;
    - a worker exits with a non-zero exit code, for example because its
      result check raised.

    Any worker still running when the test exits is terminated.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.accelerator.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # A local "spawn" context avoids changing the start method globally for
    # the whole interpreter (which set_start_method(force=True) would do).
    ctx = multiprocessing.get_context("spawn")
    # 60s is enough
    timeout = 60
    processes = []
    try:
        for rank in range(tp_size):
            p = ctx.Process(target=qr_variable_input, args=(rank, tp_size))
            p.start()
            processes.append((rank, p))
        for rank, p in processes:
            p.join(timeout=timeout)
            if p.is_alive():
                raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")
            # Without this check, a worker that failed its assertion would
            # exit non-zero and the test would still pass.
            if p.exitcode != 0:
                raise RuntimeError(
                    f"QuickReduce worker rank {rank} exited with code {p.exitcode}"
                )
    finally:
        for _, proc in processes:
            if proc.is_alive():
                proc.terminate()
                proc.join()


if __name__ == "__main__":
    # Script entry point: run the TP=4 variable-input check directly.
    test_custom_quick_allreduce_variable_input(pipeline_parallel_size=1, tp_size=4)