test_utils.py 4.52 KB
Newer Older
1
2
import socket

3
import pytest
4
import ray
5
import torch
6

7
import vllm.envs as envs
8
9
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
10
from vllm.utils import (cuda_device_count_stateless, get_open_port,
11
                        update_environment_variables)
12

13
14
from ..utils import multi_gpu_test

15
16

@ray.remote
class _CUDADeviceCountStatelessTestActor:
    """Ray actor exposing CUDA device visibility helpers to the test.

    Each method runs inside the actor, so the test can change
    ``CUDA_VISIBLE_DEVICES`` there and observe what
    ``cuda_device_count_stateless()`` reports afterwards.
    """

    def get_count(self):
        """Return the result of ``cuda_device_count_stateless()``."""
        return cuda_device_count_stateless()

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        """Pass a new ``CUDA_VISIBLE_DEVICES`` value to
        ``update_environment_variables``."""
        overrides = {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
        update_environment_variables(overrides)

    def get_cuda_visible_devices(self):
        """Return ``envs.CUDA_VISIBLE_DEVICES`` as read inside the actor."""
        return envs.CUDA_VISIBLE_DEVICES
28
29
30
31
32


def test_cuda_device_count_stateless():
    """Test that cuda_device_count_stateless changes return value if
    CUDA_VISIBLE_DEVICES is changed."""
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2).remote()

    def call(method, *args):
        # Invoke an actor method remotely and block until its result arrives.
        return ray.get(method.remote(*args))

    # The actor was requested with two GPUs, so it should see two entries.
    visible_devices = call(actor.get_cuda_visible_devices).split(",")
    assert len(visible_devices) == 2
    assert call(actor.get_count) == 2

    # Shrink the visible set step by step; the count must follow each change.
    for devices, expected_count in (("0", 1), ("", 0)):
        call(actor.set_cuda_visible_devices, devices)
        assert call(actor.get_count) == expected_count
43
44


45
46
def cpu_worker(rank, WORLD_SIZE, port1, port2):
    """Exercise ``broadcast_obj`` on two overlapping stateless groups.

    Every rank joins a group of size ``WORLD_SIZE`` on ``port1``; ranks 0-2
    additionally join a three-member group on ``port2``. Rank 2 is the
    broadcast source in both groups.
    """
    in_subgroup = rank <= 2

    world_group = StatelessProcessGroup.create(
        init_method=f"tcp://127.0.0.1:{port1}",
        rank=rank,
        world_size=WORLD_SIZE)
    if in_subgroup:
        sub_group = StatelessProcessGroup.create(
            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)

    # Source rank 2 contributes tensor([2]); every rank must receive it.
    received = world_group.broadcast_obj(torch.tensor([rank]), src=2)
    assert received.item() == 2

    if in_subgroup:
        # Source rank 2 contributes tensor([3]) within the smaller group.
        received = sub_group.broadcast_obj(torch.tensor([rank + 1]), src=2)
        assert received.item() == 3
        sub_group.barrier()
    world_group.barrier()
61
62


63
def gpu_worker(rank, WORLD_SIZE, port1, port2):
    """Run ``all_reduce`` through PyNccl on two overlapping groups.

    Every rank joins the ``port1`` group; ranks 0-2 also join a
    three-member group on ``port2``. The same tensor is reduced on the full
    group first and then, for ranks 0-2, on the smaller group.

    NOTE(review): the expected values (6, then 18) are consistent with an
    in-place sum reduction over ranks 0..3 followed by a sum over three
    ranks -- confirm against ``PyNcclCommunicator.all_reduce``.
    """
    torch.cuda.set_device(rank)
    in_subgroup = rank <= 2

    world_group = StatelessProcessGroup.create(
        init_method=f"tcp://127.0.0.1:{port1}",
        rank=rank,
        world_size=WORLD_SIZE)
    world_comm = PyNcclCommunicator(world_group, device=rank)
    world_comm.disabled = False

    # Collect (communicator, group) pairs in the order they must be reduced.
    stages = [(world_comm, world_group)]
    if in_subgroup:
        sub_group = StatelessProcessGroup.create(
            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
        sub_comm = PyNcclCommunicator(sub_group, device=rank)
        sub_comm.disabled = False
        stages.append((sub_comm, sub_group))

    data = torch.tensor([rank]).cuda()
    for comm, group in stages:
        comm.all_reduce(data)
        group.barrier()
        torch.cuda.synchronize()

    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    # Rank 3 only takes part in the first reduction.
    expected = 6 if rank == 3 else 18
    assert item == expected


91
92
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
    """Check that a string broadcast from rank 2 reaches every other rank.

    ``port2`` is accepted for signature parity with the other workers and
    is not used here.
    """
    group = StatelessProcessGroup.create(
        init_method=f"tcp://127.0.0.1:{port1}",
        rank=rank,
        world_size=WORLD_SIZE)

    is_source = rank == 2
    # Only the source supplies a payload; the receivers pass None.
    payload = "secret" if is_source else None
    result = group.broadcast_obj(payload, src=2)
    if not is_source:
        assert result == "secret"
    group.barrier()


103
104
def allgather_worker(rank, WORLD_SIZE, port1, port2):
    """Check that ``all_gather_obj`` returns every rank's value in rank order.

    ``port2`` is accepted for signature parity with the other workers and
    is not used here.
    """
    group = StatelessProcessGroup.create(
        init_method=f"tcp://127.0.0.1:{port1}",
        rank=rank,
        world_size=WORLD_SIZE)

    # Each rank contributes its own index, so the result is 0..WORLD_SIZE-1.
    expected = list(range(WORLD_SIZE))
    gathered = group.all_gather_obj(rank)
    assert gathered == expected
    group.barrier()


112
113
# TODO: investigate why this test is flaky. It hangs during initialization.
@pytest.mark.skip("Skip the test because it is flaky.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
def test_stateless_process_group(worker):
    """Run ``worker`` in four forked processes and require clean exits.

    Each process receives its rank, the world size, and two ports to use
    for the stateless process groups.
    """
    from multiprocessing import get_context

    port1 = get_open_port()
    # Keep port1 bound while requesting port2 -- presumably so that
    # get_open_port() cannot hand back the same port twice.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as placeholder:
        placeholder.bind(("", port1))
        port2 = get_open_port()

    WORLD_SIZE = 4
    ctx = get_context("fork")
    processes = [
        ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
        for rank in range(WORLD_SIZE)
    ]

    for proc in processes:
        proc.start()
    # Wait for every worker before inspecting any exit code.
    for proc in processes:
        proc.join()
    for proc in processes:
        assert not proc.exitcode
    print("All processes finished.")