# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py`.
"""

from __future__ import annotations

from typing import Any, Callable
11

12
import pytest
Simon Mo's avatar
Simon Mo committed
13
import ray
14
import torch
15

16
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
17
                              tensor_model_parallel_all_gather,
18
19
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
20

21
from ..utils import init_test_distributed_environment, multi_process_parallel
22
23


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``tensor_model_parallel_all_reduce`` against a locally computed sum.

    Every worker builds the same ``tp_size`` candidate tensors (the ``r``-th
    one is ``arange(8) * (r + 1)`` in float32 on CUDA), contributes the
    candidate at index ``rank % tp_size``, and expects the all-reduce result
    to equal the element-wise sum of all candidates.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    base = torch.arange(8, dtype=torch.float32, device="cuda")
    candidates = [base * (r + 1) for r in range(tp_size)]
    expected = torch.stack(candidates, dim=0).sum(dim=0)

    reduced = tensor_model_parallel_all_reduce(candidates[rank % tp_size])
    torch.testing.assert_close(reduced, expected)


@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int,
                               pp_size: int, rank: int,
                               distributed_init_port: str):
    """Check ``tensor_model_parallel_reduce_scatter`` along dimension 0.

    Every worker builds the same ``tp_size`` candidate tensors of 8 elements
    (the ``r``-th one is ``arange(8) * (r + 1)``), contributes the candidate
    at index ``rank % tp_size`` and expects to receive the matching
    contiguous slice of length ``8 // tp_size`` of their element-wise sum.
    Assumes 8 is divisible by ``tp_size`` (true for the ``tp_size=2`` used
    here) — TODO confirm before widening the parametrization.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    num_elements = 8
    inputs = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tp_size)
    ]

    local_index = rank % tp_size
    chunk = num_elements // tp_size
    start, stop = local_index * chunk, (local_index + 1) * chunk
    summed = torch.stack(inputs, dim=0).sum(dim=0)
    expected = summed[start:stop]

    scattered = tensor_model_parallel_reduce_scatter(inputs[local_index], 0)
    torch.testing.assert_close(scattered, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    For each dimension of a ``[2, 3, 4]`` float32 CUDA tensor, every worker
    builds ``tp_size`` candidates (the ``r``-th one is the reshaped
    ``arange`` scaled by ``r + 1``), contributes the candidate at index
    ``rank % tp_size`` and expects the gathered result to equal
    ``torch.cat`` of all candidates along that dimension.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    num_dimensions = 3
    shape = [d + 2 for d in range(num_dimensions)]  # [2, 3, 4]
    numel = torch.Size(shape).numel()

    for dim in range(num_dimensions):
        # Fresh inputs every iteration, so one dimension's collective
        # cannot influence the next one's inputs.
        pieces = [
            torch.arange(numel, dtype=torch.float32,
                         device="cuda").reshape(shape) * (r + 1)
            for r in range(tp_size)
        ]
        expected = torch.cat(pieces, dim=dim)
        gathered = tensor_model_parallel_all_gather(pieces[rank % tp_size],
                                                    dim)
        torch.testing.assert_close(gathered, expected)


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check ``broadcast_tensor_dict`` with a heterogeneous payload.

    Workers with ``rank % tp_size == 0`` broadcast the payload with
    ``src=0``. Every other worker receives a dict and checks that it has the
    same number of entries and matching values. Tensors are compared with
    ``assert_close``; plain Python values are compared with ``==``.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    payload = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    if rank % tp_size == 0:
        broadcast_tensor_dict(payload, src=0)
        return

    received = broadcast_tensor_dict(src=0)
    assert len(received) == len(payload)
    # Walk the payload in insertion order (a..f) so the first mismatch
    # reported is the same one a key-by-key check would report.
    for key, sent in payload.items():
        if isinstance(sent, torch.Tensor):
            torch.testing.assert_close(received[key], sent)
        else:
            assert received[key] == sent


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check point-to-point ``send_tensor_dict`` / ``recv_tensor_dict``.

    Every pipeline stage except the first receives a dict through
    ``recv_tensor_dict()``. Every stage except the last sends its local
    payload through ``send_tensor_dict``. The peer is whatever the pipeline
    group uses by default (presumably the adjacent stage — confirm in
    GroupCoordinator). Receivers then check the dict against their own
    identical payload.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    torch.cuda.set_device(torch.device(f"cuda:{rank}"))
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    payload = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    pp_group = get_pp_group()
    is_receiver = not pp_group.is_first_rank

    # Receive before sending so that each stage's send has a stage that is
    # already waiting on the other end.
    if is_receiver:
        received = pp_group.recv_tensor_dict()

    if not pp_group.is_last_rank:
        pp_group.send_tensor_dict(payload)

    if not is_receiver:
        return

    assert len(received) == len(payload)
    # Walk the payload in insertion order (a..f): tensors via assert_close,
    # plain Python values via ==.
    for key, sent in payload.items():
        if isinstance(sent, torch.Tensor):
            torch.testing.assert_close(received[key], sent)
        else:
            assert received[key] == sent


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
):
    """Check point-to-point tensor ``send`` / ``recv`` in the pipeline group.

    Every pipeline stage except the first receives a float32 tensor of
    ``size`` elements. Every stage except the last sends its local
    ``arange(size)``. The peer is whatever the pipeline group uses by
    default (presumably the adjacent stage — confirm in GroupCoordinator).
    Receivers compare the received tensor with their own identical copy.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see every GPU; it then
    # selects the GPU matching its rank explicitly.
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    size = 64
    # Build the payload from ``size`` (instead of repeating the literal) so
    # the sent tensor and the receive buffer can never disagree in length.
    test_tensor = torch.arange(size, dtype=torch.float32, device="cuda")

    pp_group = get_pp_group()

    # Receive before sending so that each stage's send has a stage that is
    # already waiting on the other end.
    if not pp_group.is_first_rank:
        recv_tensor = pp_group.recv(size, dtype=test_tensor.dtype)

    if not pp_group.is_last_rank:
        pp_group.send(test_tensor)

    if not pp_group.is_first_rank:
        torch.testing.assert_close(test_tensor, recv_tensor)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
    all_reduce_test_worker, reduce_scatter_test_worker,
    all_gather_test_worker, broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    test_target: Callable[..., Any],
):
    """Run each tensor-parallel collective worker with pipeline size 1.

    Covers all-reduce, reduce-scatter, all-gather and tensor-dict broadcast.
    ``reduce_scatter_test_worker`` must be listed here explicitly, otherwise
    no test collects it.
    """
    # Third argument is the pipeline-parallel size: TP-only run.
    multi_process_parallel(monkeypatch, tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
    send_recv_test_worker,
    send_recv_tensor_dict_test_worker,
])
def test_multi_process_pipeline_parallel(
    monkeypatch: pytest.MonkeyPatch,
    pp_size: int,
    test_target: Callable[..., Any],
):
    """Run each point-to-point worker with tensor-parallel size 1."""
    tp_size = 1  # pipeline-only run: no tensor parallelism
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
    send_recv_test_worker, send_recv_tensor_dict_test_worker,
    all_reduce_test_worker, reduce_scatter_test_worker,
    all_gather_test_worker, broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_pipeline_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Callable[..., Any],
    monkeypatch: pytest.MonkeyPatch,
):
    """Run every communication worker with both TP and PP enabled.

    The target list is the union of the point-to-point workers and the
    tensor-parallel collective workers, including reduce-scatter. The TP
    workers index their inputs with ``rank % tp_size``, so they also apply
    when there are several TP groups.
    """
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)