test_comm_ops.py 8.79 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Test the communication operators.

5
Run `pytest tests/distributed/test_comm_ops.py`.
6
"""
7

8
9
from collections.abc import Callable
from typing import Any
10

11
import pytest
Simon Mo's avatar
Simon Mo committed
12
import ray
13
import torch
14

15
16
17
18
19
20
21
from vllm.distributed import (
    broadcast_tensor_dict,
    get_pp_group,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter,
)
22

23
24
25
26
27
from ..utils import (
    init_test_distributed_environment,
    multi_gpu_test,
    multi_process_parallel,
)
28
29


Simon Mo's avatar
Simon Mo committed
30
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``tensor_model_parallel_all_reduce`` from a single rank.

    Each rank builds the same list of ``tp_size`` input tensors, feeds the
    one at index ``rank % tp_size`` into the all-reduce, and asserts that
    the result equals the element-wise sum of the whole list.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU and
    # then pin itself to the device that matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    base = torch.arange(num_elements, dtype=torch.float32, device="cuda")
    # Input for rank r is the base ramp scaled by (r + 1).
    per_rank_inputs = [base * (r + 1) for r in range(tp_size)]
    expected = torch.stack(per_rank_inputs, dim=0).sum(dim=0)

    reduced = tensor_model_parallel_all_reduce(per_rank_inputs[rank % tp_size])
    torch.testing.assert_close(reduced, expected)
55
56


57
@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``tensor_model_parallel_reduce_scatter`` from a single rank.

    Each rank builds the same list of ``tp_size`` input tensors, sums them
    locally, and expects the reduce-scatter (called with dimension 0) to
    return the slice of that sum selected by ``rank % tp_size``.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU and
    # then pin itself to the device that matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    base = torch.arange(num_elements, dtype=torch.float32, device="cuda")
    # Input for rank r is the base ramp scaled by (r + 1).
    per_rank_inputs = [base * (r + 1) for r in range(tp_size)]

    index = rank % tp_size
    partition_size = num_elements // tp_size
    start = index * partition_size
    stop = start + partition_size
    summed = torch.stack(per_rank_inputs, dim=0).sum(dim=0)
    expected = summed[start:stop]

    scattered = tensor_model_parallel_reduce_scatter(per_rank_inputs[index], 0)
    torch.testing.assert_close(scattered, expected)


Simon Mo's avatar
Simon Mo committed
88
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    For each dimension of a 3-D tensor, each rank builds the same list of
    ``tp_size`` inputs, gathers the one at index ``rank % tp_size``, and
    asserts the result equals the concatenation of the whole list along
    that dimension.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU and
    # then pin itself to the device that matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_dimensions = 3
    # Shape is [2, 3, ..., num_dimensions + 1].
    shape = list(range(2, num_dimensions + 2))
    total_size = torch.Size(shape).numel()

    for gather_dim in range(num_dimensions):
        per_rank_inputs = []
        for r in range(tp_size):
            ramp = torch.arange(total_size, dtype=torch.float32, device="cuda")
            per_rank_inputs.append(ramp.reshape(shape) * (r + 1))

        expected = torch.cat(per_rank_inputs, dim=gather_dim)
        gathered = tensor_model_parallel_all_gather(
            per_rank_inputs[rank % tp_size], gather_dim
        )
        torch.testing.assert_close(gathered, expected)
120
121


122
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``broadcast_tensor_dict`` with a mixed-content dictionary.

    Ranks where ``rank % tp_size == 0`` broadcast the dictionary with
    ``src=0``; every other rank receives it and compares each entry
    against its own locally built copy.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU and
    # then pin itself to the device that matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    payload = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    if rank % tp_size == 0:
        # Sending side: nothing to verify locally.
        broadcast_tensor_dict(payload, src=0)
        return

    received = broadcast_tensor_dict(src=0)
    assert len(received) == len(payload)
    for key in ("a", "b"):
        torch.testing.assert_close(received[key], payload[key])
    for key in ("c", "d", "e"):
        assert received[key] == payload[key]
    torch.testing.assert_close(received["f"], payload["f"])
160
161


162
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``send_tensor_dict`` / ``recv_tensor_dict`` on the pipeline group.

    Every rank except the first receives a dictionary, every rank except
    the last sends one, and each receiving rank then compares what it got
    against its own locally built copy.
    """
    # Remove CUDA_VISIBLE_DEVICES, then pin this worker to the device that
    # matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    payload = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    pp_group = get_pp_group()
    receives = not pp_group.is_first_rank
    sends = not pp_group.is_last_rank

    # Receive first, then send, and only verify after both calls.
    if receives:
        received = pp_group.recv_tensor_dict()

    if sends:
        pp_group.send_tensor_dict(payload)

    if receives:
        assert len(received) == len(payload)
        for key in ("a", "b"):
            torch.testing.assert_close(received[key], payload[key])
        for key in ("c", "d", "e"):
            assert received[key] == payload[key]
        torch.testing.assert_close(received["f"], payload["f"])
201
202
203


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check plain tensor ``send`` / ``recv`` on the pipeline group.

    Every rank except the first receives a 64-element float32 tensor,
    every rank except the last sends one, and each receiving rank then
    compares what it got against its own locally built copy.
    """
    # Remove CUDA_VISIBLE_DEVICES, then pin this worker to the device that
    # matches its rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    size = 64
    outgoing = torch.arange(size, dtype=torch.float32, device="cuda")

    pp_group = get_pp_group()
    receives = not pp_group.is_first_rank
    sends = not pp_group.is_last_rank

    # Receive first, then send, and only verify after both calls.
    if receives:
        incoming = pp_group.recv(size, dtype=torch.float32)

    if sends:
        pp_group.send(outgoing)

    if receives:
        torch.testing.assert_close(outgoing, incoming)
227
228


229
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        all_reduce_test_worker,
        # Listed so the reduce-scatter worker is actually exercised.
        reduce_scatter_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    test_target: Callable[..., Any],
):
    """Run each tensor-parallel worker with ``tp_size`` ranks.

    Args:
        monkeypatch: pytest fixture forwarded to ``multi_process_parallel``.
        tp_size: tensor-parallel size (parametrized).
        test_target: the worker to launch (parametrized).
    """
    # The third argument is the pipeline-parallel size; fixed to 1 here.
    multi_process_parallel(monkeypatch, tp_size, 1, test_target)
241
242


243
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
    ],
)
def test_multi_process_pipeline_parallel(
    monkeypatch: pytest.MonkeyPatch,
    pp_size: int,
    test_target: Callable[..., Any],
):
    """Run each pipeline-parallel worker with ``pp_size`` ranks.

    Args:
        monkeypatch: pytest fixture forwarded to ``multi_process_parallel``.
        pp_size: pipeline-parallel size (parametrized).
        test_target: the worker to launch (parametrized).
    """
    # Tensor parallelism is fixed to a single rank for this test.
    tp_size = 1
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)
254
255


256
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    (
        # pipeline-parallel workers
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
        # tensor-parallel workers
        all_reduce_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ),
)
def test_multi_process_tensor_parallel_pipeline_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Callable[..., Any],
    monkeypatch: pytest.MonkeyPatch,
):
    """Run each worker with tensor and pipeline parallelism combined.

    Args:
        tp_size: tensor-parallel size (parametrized).
        pp_size: pipeline-parallel size (parametrized).
        test_target: the worker to launch (parametrized).
        monkeypatch: pytest fixture forwarded to ``multi_process_parallel``.
    """
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)