"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
import pytest
import ray
import torch

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict, tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment,
                             multi_process_tensor_parallel)


@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_reduce`` on one tensor-parallel rank.

    Rank ``r`` contributes ``arange(num_elements) * (r + 1)``. After the
    all-reduce, this rank's result must equal the element-wise sum of every
    rank's contribution.

    Args:
        tensor_parallel_size: Number of ranks taking part in the all-reduce.
        rank: Index of this worker among those ranks.
        distributed_init_port: Forwarded to
            ``init_test_distributed_environment`` (presumably the rendezvous
            port for the process group — confirm against that helper).
    """
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    # Each worker builds every rank's input locally, so it can compute the
    # expected result without any extra communication.
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    # Compute the reference BEFORE the collective runs, in case the op
    # reduces its input tensor in place.
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = tensor_model_parallel_all_reduce(all_tensors[rank])
    assert torch.allclose(t, expected), (
        f"rank {rank}: all-reduce result {t} != expected {expected}")


@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    Rank ``r`` contributes a tensor of shape ``(2, 3, 4)`` holding
    ``arange(24) * (r + 1)``. For each gather dimension, the gathered result
    must equal the concatenation of all ranks' tensors along that dimension.

    Args:
        tensor_parallel_size: Number of ranks taking part in the all-gather.
        rank: Index of this worker among those ranks.
        distributed_init_port: Forwarded to
            ``init_test_distributed_environment`` (presumably the rendezvous
            port for the process group — confirm against that helper).
    """
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))  # [2, 3, 4]
    total_size = torch.Size(tensor_size).numel()
    for all_gather_dimension in range(num_dimensions):
        # Rebuild the inputs for every dimension so that each check starts
        # from pristine data, whatever the collective does to its input.
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = tensor_model_parallel_all_gather(all_tensors[rank],
                                             all_gather_dimension)
        assert torch.allclose(t, expected), (
            f"rank {rank}, dim {all_gather_dimension}: all-gather result "
            f"{t} != expected {expected}")


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    """Broadcast a mixed tensor / Python-object dict from rank 0 and verify it.

    Rank 0 sends the dict. Every other rank receives it and checks that it
    has the same number of entries and that each entry matches what was sent.
    Tensors are compared with ``torch.allclose``; plain objects with ``==``.

    Args:
        tensor_parallel_size: Number of ranks taking part in the broadcast.
        rank: Index of this worker among those ranks.
        distributed_init_port: Forwarded to
            ``init_test_distributed_environment``.
    """
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    payload = dict(
        a=torch.arange(8, dtype=torch.float32, device="cuda"),
        b=torch.arange(16, dtype=torch.int8, device="cuda"),
        c="test",
        d=[1, 2, 3],
        e={"a": 1, "b": 2},
    )

    if rank == 0:
        # Source rank: only sends; there is nothing to verify locally.
        broadcast_tensor_dict(payload, src=0)
        return

    received = broadcast_tensor_dict(src=0)
    assert len(received) == len(payload)
    for tensor_key in ("a", "b"):
        assert torch.allclose(received[tensor_key], payload[tensor_key])
    for object_key in ("c", "d", "e"):
        assert received[object_key] == payload[object_key]


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target", [
    all_reduce_test_worker, all_gather_test_worker,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    """Run one communication-op worker across ``tensor_parallel_size`` ranks.

    Skipped when fewer than two CUDA devices are visible.
    ``multi_process_tensor_parallel`` (from ``vllm.test_utils``) presumably
    launches ``test_target`` once per rank and waits for every rank — confirm
    against its implementation.
    """
    multi_process_tensor_parallel(tensor_parallel_size, test_target)