test_comm_ops.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py`.
"""
7

8
9
from collections.abc import Callable
from typing import Any
10

11
import pytest
Simon Mo's avatar
Simon Mo committed
12
import ray
13
import torch
14

15
16
17
18
19
20
21
from vllm.distributed import (
    broadcast_tensor_dict,
    get_pp_group,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter,
)
22
23
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
24

25
26
27
28
29
from ..utils import (
    init_test_distributed_environment,
    multi_gpu_test,
    multi_process_parallel,
)
30
31


Simon Mo's avatar
Simon Mo committed
32
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check ``tensor_model_parallel_all_reduce`` against a locally computed sum.

    Every rank builds the same list of ``tp_size`` tensors, contributes the
    one at index ``rank % tp_size``, and asserts that the all-reduce result
    equals the element-wise sum of the whole list.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    # One tensor per tensor-parallel rank, scaled by (r + 1) so that every
    # rank's contribution is distinct.
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
        for r in range(tp_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank % tp_size]
    t = tensor_model_parallel_all_reduce(t)
    torch.testing.assert_close(t, expected)
57
58


59
@ray.remote(num_gpus=1, max_calls=1)
def reduce_scatter_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check ``tensor_model_parallel_reduce_scatter`` along dimension 0.

    Every rank builds the same list of ``tp_size`` tensors, contributes the
    one at index ``rank % tp_size``, and asserts that the result equals this
    rank's contiguous partition of the element-wise sum of the whole list.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_elements = 8
    # The expected slice below is only meaningful when the tensor splits
    # evenly across ranks; fail loudly instead of comparing a truncated slice.
    assert num_elements % tp_size == 0, (
        f"num_elements ({num_elements}) must be divisible by tp_size ({tp_size})"
    )
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1)
        for r in range(tp_size)
    ]

    index = rank % tp_size
    partition_size = num_elements // tp_size
    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    expected = all_reduce[index * partition_size : (index + 1) * partition_size]
    t = all_tensors[index]
    t = tensor_model_parallel_reduce_scatter(t, 0)
    torch.testing.assert_close(t, expected)


Simon Mo's avatar
Simon Mo committed
90
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    Every rank builds the same list of ``tp_size`` 3-D tensors, contributes
    the one at index ``rank % tp_size``, and for each dimension asserts that
    the gathered result equals ``torch.cat`` of the whole list along that
    dimension.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    num_dimensions = 3
    # Shape is [2, 3, 4] for num_dimensions == 3.
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s

    # The inputs do not depend on the gather dimension, so build them once
    # rather than on every loop iteration.
    all_tensors = [
        torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(
            tensor_size
        )
        * (r + 1)
        for r in range(tp_size)
    ]
    local_tensor = all_tensors[rank % tp_size]

    for all_gather_dimension in range(num_dimensions):
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = tensor_model_parallel_all_gather(local_tensor, all_gather_dimension)
        torch.testing.assert_close(t, expected)
122
123


124
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check that ``broadcast_tensor_dict`` delivers a mixed dict intact.

    Ranks with ``rank % tp_size == 0`` broadcast a dict holding a device
    tensor, a CPU tensor, an empty tensor and plain Python values; every other
    rank receives it and compares each entry with its own local copy.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)

    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
    test_dict = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    if (rank % tp_size) == 0:
        broadcast_tensor_dict(test_dict, src=0)
    else:
        recv_dict = broadcast_tensor_dict(src=0)
        # Fail with a clear assertion rather than a TypeError from len(None).
        assert recv_dict is not None
        assert len(recv_dict) == len(test_dict)
        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]
        torch.testing.assert_close(recv_dict["f"], test_dict["f"])
162
163


164
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check ``send_tensor_dict`` / ``recv_tensor_dict`` on the pipeline group.

    Every rank except the first receives a dict, every rank except the last
    sends its local dict, and each receiving rank compares what it got with
    its own local copy.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    test_dict = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {"a": 1, "b": 2},
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    # Receive before sending, in the same order as the original test.
    if not get_pp_group().is_first_rank:
        recv_dict = get_pp_group().recv_tensor_dict()

    if not get_pp_group().is_last_rank:
        get_pp_group().send_tensor_dict(test_dict)

    if not get_pp_group().is_first_rank:
        # Fail with a clear assertion rather than a TypeError from len(None).
        assert recv_dict is not None
        assert len(recv_dict) == len(test_dict)
        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]
        torch.testing.assert_close(recv_dict["f"], test_dict["f"])
203
204


205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
class _DummyWork:
    def __init__(self) -> None:
        self.wait_calls = 0

    def wait(self) -> None:
        self.wait_calls += 1


class _DummyAllGatherGroup:
    def __init__(self, world_size: int, rank_in_group: int) -> None:
        self.world_size = world_size
        self.rank_in_group = rank_in_group

    def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
        # duplicate local slice across ranks.
        assert dim == 0
        return torch.cat([t for _ in range(self.world_size)], dim=0)


def _make_group_for_unit_test(
    rank_in_group: int = 0, world_size: int = 2
) -> GroupCoordinator:
    """Build a bare ``GroupCoordinator`` with hand-set attributes for unit tests.

    Args:
        rank_in_group: Value stored on the instance as ``rank_in_group``.
        world_size: Value stored as ``world_size``; ``ranks`` is set to
            ``list(range(world_size))``.

    Returns:
        A ``GroupCoordinator`` whose ``__init__`` was never run.
    """
    # avoid running GroupCoordinator.__init__ (it wires up real process groups).
    group = GroupCoordinator.__new__(GroupCoordinator)
    stub_attributes = {
        "world_size": world_size,
        "rank_in_group": rank_in_group,
        "ranks": list(range(world_size)),
        "use_cpu_custom_send_recv": False,
        "device_group": None,
        "cpu_group": None,
    }
    for attr_name, attr_value in stub_attributes.items():
        setattr(group, attr_name, attr_value)
    return group


def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Each postprocess callback from ``irecv_tensor_dict`` updates its own key.

    ``torch.distributed.irecv`` is replaced by a fake that fills the receive
    buffer with ones, and the all-gather group is a dummy that tiles the local
    slice. Both entries must grow from shape 2 to shape 4 after postprocess.
    """

    def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
        t.fill_(1)
        return _DummyWork()

    def fake_recv_object(src=None):
        return metadata_list

    monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
    monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)

    keys = ("a", "b")
    # 2 tensors so we can catch late-binding bugs in postprocess closures.
    metadata_list = [
        (key, TensorMetadata("cpu", torch.int32, torch.Size([4]))) for key in keys
    ]

    group = _make_group_for_unit_test(rank_in_group=0, world_size=2)
    group.recv_object = fake_recv_object  # type: ignore[method-assign]

    gather_group = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
    tensor_dict, handles, postprocess = group.irecv_tensor_dict(
        all_gather_group=gather_group
    )

    assert tensor_dict is not None
    assert len(handles) == 2
    assert len(postprocess) == 2

    # before postprocess, dict holds the TP slice (shape 2).
    for key in keys:
        assert tensor_dict[key].shape == torch.Size([2])

    # simulate worker-side "defer wait": wait + postprocess later.
    for pending in handles:
        pending.wait()
    for callback in postprocess:
        callback()

    # after postprocess, dict values are reconstructed to full shape (shape 4),
    # and each key should be updated independently
    for key in keys:
        assert tensor_dict[key].shape == torch.Size([4])
    full_ones = torch.ones(4, dtype=torch.int32)
    for key in keys:
        torch.testing.assert_close(tensor_dict[key], full_ones)


def test_async_intermediate_tensors_lazy_wait() -> None:
    """``AsyncIntermediateTensors`` waits and postprocesses once, on first use.

    Reading ``kv_connector_output`` must not touch the handle; the first read
    of ``.tensors`` must wait and run the postprocess hook exactly once, and
    later reads must not repeat either.
    """
    handle = _DummyWork()
    post_invocations: list[None] = []

    def record_post() -> None:
        post_invocations.append(None)

    intermediate = AsyncIntermediateTensors(
        {"x": torch.tensor([1])},
        comm_handles=[handle],
        comm_postprocess=[record_post],
    )

    # accessing non-tensor attributes should not trigger wait.
    assert intermediate.kv_connector_output is None
    assert handle.wait_calls == 0
    assert len(post_invocations) == 0

    # first access of `.tensors` triggers wait + postprocess.
    _ = intermediate.tensors
    assert handle.wait_calls == 1
    assert len(post_invocations) == 1

    # subsequent access should not re-wait.
    _ = intermediate.tensors
    assert handle.wait_calls == 1
    assert len(post_invocations) == 1


310
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
) -> None:
    """Check point-to-point ``send`` / ``recv`` on the pipeline group.

    Every rank except the first receives a tensor, every rank except the last
    sends its local tensor, and each receiving rank compares what it got with
    its own local copy.

    Args:
        monkeypatch: Used to drop ``CUDA_VISIBLE_DEVICES`` from the environment.
        tp_size: Tensor-parallel size passed to
            ``init_test_distributed_environment``.
        pp_size: Pipeline-parallel size passed to
            ``init_test_distributed_environment``.
        rank: Rank of this worker; also selects the ``cuda:{rank}`` device.
        distributed_init_port: Port passed to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

    size = 64
    # Derive the payload length from ``size`` so the sent tensor and the
    # receive buffer cannot drift apart.
    test_tensor = torch.arange(size, dtype=torch.float32, device="cuda")

    # Receive before sending, in the same order as the original test.
    if not get_pp_group().is_first_rank:
        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)

    if not get_pp_group().is_last_rank:
        get_pp_group().send(test_tensor)

    if not get_pp_group().is_first_rank:
        torch.testing.assert_close(test_tensor, recv_tensor)
334
335


336
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        all_reduce_test_worker,
        # reduce_scatter_test_worker was defined in this module but never
        # referenced by any test, so it never ran.
        reduce_scatter_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    test_target: Callable[..., Any],
):
    """Run each tensor-parallel worker with ``tp_size`` and pipeline size 1."""
    multi_process_parallel(monkeypatch, tp_size, 1, test_target)
348
349


350
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]
)
def test_multi_process_pipeline_parallel(
    monkeypatch: pytest.MonkeyPatch,
    pp_size: int,
    test_target: Callable[..., Any],
):
    """Run each pipeline-parallel worker with tensor size 1 and ``pp_size``."""
    multi_process_parallel(monkeypatch, 1, pp_size, test_target)
361
362


363
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
        all_reduce_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel_pipeline_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Callable[..., Any],
    monkeypatch: pytest.MonkeyPatch,
):
    """Run each worker with both ``tp_size`` and ``pp_size`` in effect."""
    multi_process_parallel(monkeypatch, tp_size, pp_size, test_target)