"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
import pytest
import torch
import ray

from vllm.config import ParallelConfig
from vllm.utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_all_gather,
    broadcast_tensor_dict,
)
from vllm.worker.worker import _init_distributed_environment


def init_test_distributed_environment(pipeline_parallel_size: int,
                                      tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
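    """Initialize a test-scoped distributed environment.

    Builds a Ray-backed ParallelConfig and joins the distributed process
    group via a TCP rendezvous on localhost at the given port.
    """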
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    _init_distributed_environment(parallel_config, rank,
                                  distributed_init_method)


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
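    """Each rank builds arange(num_elements) scaled by (rank + 1); after the
    all-reduce, every rank must hold the elementwise sum over all ranks."""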
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
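    """Each rank builds a (2, 3, 4) tensor scaled by (rank + 1) and checks
    that all-gather along every dimension matches torch.cat of the
    per-rank tensors along that dimension."""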
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
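    """Rank 0 broadcasts a dict mixing CUDA tensors and plain Python objects;
    every other rank must receive an identical copy."""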
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    test_dict = {
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
    }

    if rank == 0:
        broadcast_tensor_dict(test_dict, src=0)
    else:
        recv_dict = broadcast_tensor_dict(src=0)
        assert len(recv_dict) == len(test_dict)
        assert torch.allclose(recv_dict["a"], test_dict["a"])
        assert torch.allclose(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [
    all_reduce_test_worker, all_gather_test_worker,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
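    """Spawn one Ray worker per rank and run the given collective-op test
    on all of them concurrently."""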
    # Ray is used instead of multiprocessing because it makes
    # failures easier to debug.
    ray.init()

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tensor_parallel_size):
        refs.append(
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port))
    ray.get(refs)

    ray.shutdown()