test_comm_ops.py 4.26 KB
Newer Older
1
2
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py`.
"""
5
6
import os

7
import pytest
Simon Mo's avatar
Simon Mo committed
8
import ray
9
import torch
10
11

from vllm.model_executor.parallel_utils.communication_op import (
12
13
    broadcast_tensor_dict, tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce)
14
15
from vllm.test_utils import (init_test_distributed_environment,
                             multi_process_tensor_parallel)
16
17


Simon Mo's avatar
Simon Mo committed
18
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_reduce`` from one rank's viewpoint.

    Every worker builds the same list of inputs (``arange(8)`` scaled by
    ``r + 1`` for each ``r`` in ``range(tensor_parallel_size)``), feeds the
    entry at its own ``rank`` into the all-reduce and asserts that the result
    equals the element-wise sum of all the inputs.

    Args:
        tensor_parallel_size: Number of per-rank inputs to generate; also
            forwarded to ``init_test_distributed_environment``.
        rank: Index of this worker. Selects both the ``cuda:{rank}`` device
            and which generated input this worker contributes.
        distributed_init_port: Forwarded unchanged to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If the reduced tensor differs from the expected sum.
    """
    # It is important to remove the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and can then set its device to
    # the correct one. ``pop`` with a default is used instead of ``del`` so
    # that a worker whose environment does not define the variable does not
    # fail with a KeyError before the actual test starts.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    # The expected result is computed locally from all ranks' inputs.
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


Simon Mo's avatar
Simon Mo committed
40
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    """Check ``tensor_model_parallel_all_gather`` along every dimension.

    For each dimension of a ``2 x 3 x 4`` tensor, every worker builds the
    same list of inputs (an ``arange`` reshaped to that size and scaled by
    ``r + 1``), feeds the entry at its own ``rank`` into the all-gather and
    asserts that the result equals the concatenation of all the inputs along
    that dimension.

    Args:
        tensor_parallel_size: Number of per-rank inputs to generate; also
            forwarded to ``init_test_distributed_environment``.
        rank: Index of this worker. Selects both the ``cuda:{rank}`` device
            and which generated input this worker contributes.
        distributed_init_port: Forwarded unchanged to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a gathered tensor differs from the expected
            concatenation.
    """
    # It is important to remove the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and can then set its device to
    # the correct one. ``pop`` with a default is used instead of ``del`` so
    # that a worker whose environment does not define the variable does not
    # fail with a KeyError before the actual test starts.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    # Sizes 2, 3, 4: each dimension has a distinct length.
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        # The expected result is computed locally from all ranks' inputs.
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


68
69
70
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    """Check ``broadcast_tensor_dict`` with tensor and non-tensor values.

    Every worker builds the same reference dictionary holding two CUDA
    tensors, a string, a list and a nested dict. Rank 0 broadcasts it with
    ``src=0``; every other rank calls ``broadcast_tensor_dict(src=0)`` and
    asserts that what it gets back matches the reference entry by entry.

    Args:
        tensor_parallel_size: Forwarded to
            ``init_test_distributed_environment``.
        rank: Index of this worker. Selects the ``cuda:{rank}`` device and
            whether this worker sends (rank 0) or receives and checks.
        distributed_init_port: Forwarded unchanged to
            ``init_test_distributed_environment``.

    Raises:
        AssertionError: If a receiving rank's dictionary differs from the
            reference dictionary.
    """
    # It is important to remove the CUDA_VISIBLE_DEVICES environment variable
    # so that each worker can see all the GPUs and can then set its device to
    # the correct one. ``pop`` with a default is used instead of ``del`` so
    # that a worker whose environment does not define the variable does not
    # fail with a KeyError before the actual test starts.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    test_dict = {
        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
        "c": "test",
        "d": [1, 2, 3],
        "e": {
            "a": 1,
            "b": 2
        },
    }

    if rank == 0:
        broadcast_tensor_dict(test_dict, src=0)
    else:
        recv_dict = broadcast_tensor_dict(src=0)
        assert len(recv_dict) == len(test_dict)
        assert torch.allclose(recv_dict["a"], test_dict["a"])
        assert torch.allclose(recv_dict["b"], test_dict["b"])
        assert recv_dict["c"] == test_dict["c"]
        assert recv_dict["d"] == test_dict["d"]
        assert recv_dict["e"] == test_dict["e"]


102
103
104
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        all_reduce_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    """Hand each communication-op worker to ``multi_process_tensor_parallel``.

    Parametrized over the three Ray worker functions in this module with a
    tensor-parallel size of 2; skipped when fewer than two CUDA devices are
    visible.
    """
    multi_process_tensor_parallel(tensor_parallel_size, test_target)