# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py`.
"""

from __future__ import annotations

from typing import Any, Callable
import pytest
import ray
import torch
from vllm.distributed import (
    broadcast_tensor_dict,
    get_pp_group,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter,
)
from ..utils import (
    init_test_distributed_environment,
    multi_gpu_test,
    multi_process_parallel,
)


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify that a tensor-parallel all-reduce sums every participant's tensor.

    Executed as a Ray remote task reserving one GPU. Participant ``i`` (for
    ``i`` in ``range(tp_size)``) is modelled as holding ``arange(8) * (i + 1)``;
    this worker feeds entry ``rank % tp_size`` into the all-reduce and expects
    the element-wise sum of all ``tp_size`` entries back.
    """
    # Removing CUDA_VISIBLE_DEVICES lets this worker see every GPU, so it can
    # pin itself to the device whose index equals its global rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    base = torch.arange(num_elements, dtype=torch.float32, device="cuda")
    contributions = [base * (i + 1) for i in range(tp_size)]
    expected = torch.stack(contributions, dim=0).sum(dim=0)

    # NOTE(review): assumes TP groups are laid out so that rank % tp_size is
    # this worker's position inside its TP group.
    reduced = tensor_model_parallel_all_reduce(contributions[rank % tp_size])
    torch.testing.assert_close(reduced, expected)


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify tensor-parallel reduce-scatter along dimension 0.

    Executed as a Ray remote task reserving one GPU. Participant ``i`` is
    modelled as holding ``arange(8) * (i + 1)``. This worker (``i = rank %
    tp_size``) reduce-scatters its entry and expects chunk ``i`` of the
    element-wise sum, where the sum is split into ``8 // tp_size``-element
    chunks. No divisibility check is made, so ``tp_size`` is expected to
    divide 8.
    """
    # Removing CUDA_VISIBLE_DEVICES lets this worker see every GPU, so it can
    # pin itself to the device whose index equals its global rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    contributions = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") * (i + 1)
        for i in range(tp_size)
    ]
    local_index = rank % tp_size
    chunk = num_elements // tp_size
    summed = torch.stack(contributions, dim=0).sum(dim=0)
    start = local_index * chunk
    expected = summed[start : start + chunk]

    local_input = contributions[local_index]
    scattered = tensor_model_parallel_reduce_scatter(local_input, 0)
    torch.testing.assert_close(scattered, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify tensor-parallel all-gather along each axis of a 3-D tensor.

    Executed as a Ray remote task reserving one GPU. For every gather
    dimension, participant ``i`` is modelled as holding a ``(2, 3, 4)``
    tensor ``arange(24).reshape(2, 3, 4) * (i + 1)``. This worker gathers
    entry ``rank % tp_size`` and expects all entries concatenated along that
    dimension.
    """
    # Removing CUDA_VISIBLE_DEVICES lets this worker see every GPU, so it can
    # pin itself to the device whose index equals its global rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_dimensions = 3
    # Shape [2, 3, 4]: every axis has a distinct extent.
    tensor_size = [d + 2 for d in range(num_dimensions)]
    total_size = torch.Size(tensor_size).numel()

    for gather_dim in range(num_dimensions):
        shards = [
            torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
                tensor_size
            )
            * (i + 1)
            for i in range(tp_size)
        ]
        expected = torch.cat(shards, dim=gather_dim)
        local_shard = shards[rank % tp_size]
        gathered = tensor_model_parallel_all_gather(local_shard, gather_dim)
        torch.testing.assert_close(gathered, expected)


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify ``broadcast_tensor_dict`` on a dict of tensors and plain objects.

    Executed as a Ray remote task reserving one GPU. The payload mixes:
    - a CUDA tensor;
    - a CPU int8 tensor;
    - a string, a list and a nested dict;
    - an empty CUDA tensor.

    The worker with ``rank % tp_size == 0`` calls
    ``broadcast_tensor_dict(test_dict, src=0)``. Every other worker calls
    ``broadcast_tensor_dict(src=0)`` and checks each received entry against
    its own locally built copy.
    """
    # Removing CUDA_VISIBLE_DEVICES lets this worker see every GPU, so it can
    # pin itself to the device whose index equals its global rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    test_dict = {
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),  # device tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),  # CPU tensor
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),  # empty tensor
    }

    if rank % tp_size == 0:
        broadcast_tensor_dict(test_dict, src=0)
        return

    recv_dict = broadcast_tensor_dict(src=0)
    assert len(recv_dict) == len(test_dict)
    # Tensors need tolerance-aware comparison; everything else uses ==.
    for key, sent in test_dict.items():
        if isinstance(sent, torch.Tensor):
            torch.testing.assert_close(recv_dict[key], sent)
        else:
            assert recv_dict[key] == sent


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify point-to-point tensor-dict transfer over the pipeline-parallel group.

    Executed as a Ray remote task reserving one GPU. Every PP rank that is
    not first calls ``recv_tensor_dict()``; every PP rank that is not last
    then calls ``send_tensor_dict`` with its own payload. Receivers compare
    what arrived against a locally built copy of the same dict.
    """
    # Removing CUDA_VISIBLE_DEVICES lets this worker see every GPU, so it can
    # pin itself to the device whose index equals its global rank.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    test_dict = {
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),  # device tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),  # CPU tensor
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),  # empty tensor
    }

    pp_group = get_pp_group()
    recv_dict = None
    # Receive first, then forward: a middle stage does both, in this order.
    if not pp_group.is_first_rank:
        recv_dict = pp_group.recv_tensor_dict()
    if not pp_group.is_last_rank:
        pp_group.send_tensor_dict(test_dict)
    if pp_group.is_first_rank:
        return

    assert len(recv_dict) == len(test_dict)
    # Tensors need tolerance-aware comparison; everything else uses ==.
    for key, sent in test_dict.items():
        if isinstance(sent, torch.Tensor):
            torch.testing.assert_close(recv_dict[key], sent)
        else:
            assert recv_dict[key] == sent


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Verify raw tensor send/recv over the pipeline-parallel group.

    Executed as a Ray remote task reserving one GPU. Every PP rank that is
    not first receives a ``size``-element float32 tensor via ``recv``; every
    PP rank that is not last then sends its own ``arange(size)`` via
    ``send``. Receivers compare what arrived against their local copy.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    size = 64
    # Build the payload from `size` (not a second literal) so the sent tensor
    # and the receive length requested below cannot drift apart.
    test_tensor = torch.arange(size, dtype=torch.float32, device="cuda")

    pp_group = get_pp_group()
    # Bound up front so the name is always defined for static analysis.
    recv_tensor = None
    if not pp_group.is_first_rank:
        recv_tensor = pp_group.recv(size, dtype=torch.float32)

    if not pp_group.is_last_rank:
        pp_group.send(test_tensor)

    if not pp_group.is_first_rank:
        # Deliberately keyed on rank rather than `recv_tensor is not None`,
        # so that a recv returning nothing still fails the comparison.
        torch.testing.assert_close(test_tensor, recv_tensor)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        all_reduce_test_worker,
        # Previously defined in this module but never exercised by any test.
        reduce_scatter_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    test_target: Callable[..., Any],
):
    """Run each tensor-parallel collective worker with ``tp_size`` TP ranks.

    The pipeline-parallel size passed to ``multi_process_parallel`` is fixed
    at 1.
    """
    multi_process_parallel(monkeypatch, tp_size, 1, test_target)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
    ],
)
def test_multi_process_pipeline_parallel(
    monkeypatch: pytest.MonkeyPatch,
    pp_size: int,
    test_target: Callable[..., Any],
):
    """Run each point-to-point worker with ``pp_size`` PP stages and TP size 1."""
    tp_size = 1
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
        all_reduce_test_worker,
        # Previously defined in this module but never exercised by any test.
        reduce_scatter_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel_pipeline_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Callable[..., Any],
    monkeypatch: pytest.MonkeyPatch,
):
    """Run every worker, both TP collectives and PP point-to-point ops, with TP and PP combined.

    Both ``tp_size`` and ``pp_size`` are greater than 1 here, so each
    collective runs while the other parallel dimension is also active.
    """
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)