test_comm_ops.py 7.46 KB
Newer Older
1
2
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py`.
"""
5
6
import os

7
import pytest
Simon Mo's avatar
Simon Mo committed
8
import ray
9
import torch
10

11
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
12
13
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
14

15
from ..utils import init_test_distributed_environment, multi_process_parallel
16
17


Simon Mo's avatar
Simon Mo committed
18
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_reduce`` against a local sum.

    Every worker builds the same list of ``tp_size`` tensors
    (``arange(8) * (r + 1)`` for ``r`` in ``range(tp_size)``), feeds the one
    at index ``rank % tp_size`` into the all-reduce, and asserts the result
    equals the element-wise sum of the whole list.

    Args:
        tp_size: Tensor-parallel size passed to the test environment setup.
        pp_size: Pipeline-parallel size passed to the test environment setup.
        rank: Rank of this worker; also used as its CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    # NOTE: pop() instead of del -- if the variable is absent every GPU is
    # already visible, so there is nothing to remove and no reason to raise.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tp_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank % tp_size]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


Simon Mo's avatar
Simon Mo committed
40
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    For each dimension of a ``[2, 3, 4]``-shaped tensor, every worker builds
    the same list of ``tp_size`` tensors (scaled by ``r + 1``), feeds the one
    at index ``rank % tp_size`` into the all-gather, and asserts the result
    equals ``torch.cat`` of the whole list along that dimension.

    Args:
        tp_size: Tensor-parallel size passed to the test environment setup.
        pp_size: Pipeline-parallel size passed to the test environment setup.
        rank: Rank of this worker; also used as its CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    # NOTE: pop() instead of del -- if the variable is absent every GPU is
    # already visible, so there is nothing to remove and no reason to raise.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    # Distinct extent per dimension (2, 3, 4) so a gather along the wrong
    # dimension produces a visibly different shape.
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tp_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank % tp_size]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


68
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
                                      distributed_init_port: str):
    """Check that ``broadcast_tensor_dict`` delivers a mixed dict intact.

    Workers whose ``rank % tp_size`` is 0 broadcast a dict holding a CUDA
    tensor, a CPU tensor, plain Python values and an empty tensor; all other
    workers receive it and compare every entry against a locally built copy.

    Args:
        tp_size: Tensor-parallel size passed to the test environment setup.
        pp_size: Pipeline-parallel size passed to the test environment setup.
        rank: Rank of this worker; also used as its CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.
    """
    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs
    # they will be able to set the device to the correct GPU
    # NOTE: pop() instead of del -- if the variable is absent every GPU is
    # already visible, so there is nothing to remove and no reason to raise.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)
    test_dict = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    if (rank % tp_size) == 0:
        # Sender side: the return value is not needed.
        broadcast_tensor_dict(test_dict, src=0)
    else:
        # Receiver side: no dict is passed in, the broadcast one is returned.
        recv_dict = broadcast_tensor_dict(src=0)
        assert len(recv_dict) == len(test_dict)
        assert torch.allclose(recv_dict["a"], test_dict["a"])
        assert torch.allclose(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]
        assert torch.allclose(recv_dict["f"], test_dict["f"])
105
106


107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
                                      distributed_init_port: str):
    """Check ``send_tensor_dict`` / ``recv_tensor_dict`` on the PP group.

    Every worker except the first rank of the pipeline group receives a dict,
    every worker except the last rank sends one, and each receiver compares
    what it got against a locally built copy of the same dict.

    Args:
        tp_size: Tensor-parallel size passed to the test environment setup.
        pp_size: Pipeline-parallel size passed to the test environment setup.
        rank: Rank of this worker; also used as its CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see all the GPUs and
    # select its own by rank. pop() instead of del: if the variable is absent
    # every GPU is already visible, so there is no reason to raise KeyError.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    test_dict = {
        # device tensor
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        # CPU tensor
        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
        # empty tensor
        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
    }

    pp_group = get_pp_group()

    # Receive before sending, so a middle rank takes its input before it
    # passes a dict on.
    if not pp_group.is_first_rank:
        recv_dict = pp_group.recv_tensor_dict()

    if not pp_group.is_last_rank:
        pp_group.send_tensor_dict(test_dict)

    if not pp_group.is_first_rank:
        assert len(recv_dict) == len(test_dict)
        assert torch.allclose(recv_dict["a"], test_dict["a"])
        assert torch.allclose(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]
        assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
                          distributed_init_port: str):
    """Check point-to-point ``send`` / ``recv`` on the PP group.

    Every worker except the first rank of the pipeline group receives a
    64-element float32 tensor, every worker except the last rank sends one,
    and each receiver compares what it got against a locally built copy.

    Args:
        tp_size: Tensor-parallel size passed to the test environment setup.
        pp_size: Pipeline-parallel size passed to the test environment setup.
        rank: Rank of this worker; also used as its CUDA device index.
        distributed_init_port: Port forwarded to
            ``init_test_distributed_environment``.
    """
    # Remove CUDA_VISIBLE_DEVICES so this worker can see all the GPUs and
    # select its own by rank. pop() instead of del: if the variable is absent
    # every GPU is already visible, so there is no reason to raise KeyError.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    # One constant drives both the sent tensor and the receive size so the
    # two cannot drift apart.
    size = 64
    test_tensor = torch.arange(size, dtype=torch.float32, device="cuda")

    pp_group = get_pp_group()

    if not pp_group.is_first_rank:
        recv_tensor = pp_group.recv(size, dtype=torch.float32)

    if not pp_group.is_last_rank:
        pp_group.send(test_tensor)

    if not pp_group.is_first_rank:
        assert torch.allclose(test_tensor, recv_tensor)


169
170
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
    all_reduce_test_worker,
    all_gather_test_worker,
    broadcast_tensor_dict_test_worker,
])
def test_multi_process_tensor_parallel(tp_size, test_target):
    """Run each tensor-parallel worker across ``tp_size`` processes.

    The second argument to ``multi_process_parallel`` is fixed at 1 here;
    the combined TP+PP test passes ``pp_size`` in that position.
    """
    single_stage = 1
    multi_process_parallel(tp_size, single_stage, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
    send_recv_test_worker,
    send_recv_tensor_dict_test_worker,
])
def test_multi_process_pipeline_parallel(pp_size, test_target):
    """Run each pipeline-parallel worker across ``pp_size`` processes.

    The first argument to ``multi_process_parallel`` is fixed at 1 here;
    the combined TP+PP test passes ``tp_size`` in that position.
    """
    single_shard = 1
    multi_process_parallel(single_shard, pp_size, test_target)
187
188
189
190
191
192
193
194
195
196
197
198
199
200


@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
    send_recv_test_worker,
    send_recv_tensor_dict_test_worker,
    all_reduce_test_worker,
    all_gather_test_worker,
    broadcast_tensor_dict_test_worker,
])
def test_multi_process_tensor_parallel_pipeline_parallel(
    tp_size,
    pp_size,
    test_target,
):
    """Run every worker with both ``tp_size`` and ``pp_size`` set to 2.

    Skipped when fewer than 4 GPUs are available.
    """
    multi_process_parallel(tp_size, pp_size, test_target)