# test_utils.py (4.4 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import socket

6
import pytest
7
import ray
8
import torch
9

10
import vllm.envs as envs
11
12
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
13
from vllm.platforms import current_platform
14
from vllm.utils.network_utils import get_open_port
15
from vllm.utils.system_utils import update_environment_variables
16
from vllm.utils.torch_utils import cuda_device_count_stateless
17

18
19
from ..utils import multi_gpu_test

20
21

@ray.remote
class _CUDADeviceCountStatelessTestActor:
    """Ray actor that reports and mutates the CUDA device visibility of its
    own process, so ``cuda_device_count_stateless`` can be probed in isolation."""

    def get_count(self):
        """Return what ``cuda_device_count_stateless`` reports right now."""
        device_count = cuda_device_count_stateless()
        return device_count

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        """Overwrite CUDA_VISIBLE_DEVICES in this actor's environment."""
        new_env = {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
        update_environment_variables(new_env)

    def get_cuda_visible_devices(self):
        """Return CUDA_VISIBLE_DEVICES as read through ``vllm.envs``."""
        visible = envs.CUDA_VISIBLE_DEVICES
        return visible


def test_cuda_device_count_stateless():
    """Test that cuda_device_count_stateless changes return value if
    CUDA_VISIBLE_DEVICES is changed.

    The check runs inside a Ray actor that reserves two GPUs, so it starts in a
    fresh process whose CUDA_VISIBLE_DEVICES is (presumably) populated by Ray.
    The actor is always killed afterwards so its GPUs are released even when an
    assertion fails.
    """
    if current_platform.is_rocm():
        pytest.skip("Skip for ROCm because Ray uses HIP_VISIBLE_DEVICES.")
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2
    ).remote()
    try:
        # The actor should see exactly the two GPUs it reserved.
        visible = ray.get(actor.get_cuda_visible_devices.remote())
        assert len(visible.split(",")) == 2
        assert ray.get(actor.get_count.remote()) == 2
        # The count must follow later changes to CUDA_VISIBLE_DEVICES.
        ray.get(actor.set_cuda_visible_devices.remote("0"))
        assert ray.get(actor.get_count.remote()) == 1
        ray.get(actor.set_cuda_visible_devices.remote(""))
        assert ray.get(actor.get_count.remote()) == 0
    finally:
        # Without this, a failed assertion would leave the actor (and its two
        # reserved GPUs) alive for the rest of the Ray session.
        ray.kill(actor)


def cpu_worker(rank, WORLD_SIZE, port1, port2):
    """Check object broadcast over two overlapping CPU process groups.

    Every rank joins the first group (size ``WORLD_SIZE``, on ``port1``); only
    ranks 0-2 also join a second group of size 3 on ``port2``.
    """
    in_second_group = rank <= 2
    first_group = StatelessProcessGroup.create(
        host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
    )
    if in_second_group:
        second_group = StatelessProcessGroup.create(
            host="127.0.0.1", port=port2, rank=rank, world_size=3
        )

    # Rank 2 is the source, so every rank must receive tensor([2]).
    received = first_group.broadcast_obj(torch.tensor([rank]), src=2)
    assert received.item() == 2

    if in_second_group:
        # Rank 2 now sends rank + 1, so members must receive tensor([3]).
        received = second_group.broadcast_obj(torch.tensor([rank + 1]), src=2)
        assert received.item() == 3
        second_group.barrier()
    first_group.barrier()


def gpu_worker(rank, WORLD_SIZE, port1, port2):
    """Check PyNccl all-reduce over two overlapping GPU process groups.

    All ``WORLD_SIZE`` ranks reduce their rank id through the first group;
    ranks 0-2 then reduce that result again through a 3-member second group.
    """
    torch.cuda.set_device(rank)
    joins_pg2 = rank <= 2

    pg1 = StatelessProcessGroup.create(
        host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
    )
    pynccl1 = PyNcclCommunicator(pg1, device=rank)
    if joins_pg2:
        pg2 = StatelessProcessGroup.create(
            host="127.0.0.1", port=port2, rank=rank, world_size=3
        )
        pynccl2 = PyNcclCommunicator(pg2, device=rank)

    tensor = torch.tensor([rank]).cuda()

    pynccl1.all_reduce(tensor)
    pg1.barrier()
    torch.cuda.synchronize()

    if joins_pg2:
        pynccl2.all_reduce(tensor)
        pg2.barrier()
        torch.cuda.synchronize()

    value = tensor[0].item()
    print(f"rank: {rank}, item: {value}")
    # Hard-coded expectations match a summing all-reduce over 4 ranks:
    # 0+1+2+3 == 6 after the first group, then 6 * 3 == 18 for ranks 0-2.
    expected = 6 if rank == 3 else 18
    assert value == expected


def broadcast_worker(rank, WORLD_SIZE, port1, port2):
    """Check that a string broadcast from rank 2 reaches every other rank."""
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
    )
    if rank != 2:
        # Receivers pass a placeholder and take the source's object back.
        received = group.broadcast_obj(None, src=2)
        assert received == "secret"
    else:
        group.broadcast_obj("secret", src=2)
    group.barrier()


def allgather_worker(rank, WORLD_SIZE, port1, port2):
    """Check that all_gather_obj collects every rank id in rank order."""
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
    )
    expected = list(range(WORLD_SIZE))
    gathered = group.all_gather_obj(rank)
    assert gathered == expected
    group.barrier()


@pytest.mark.skip(reason="This test is flaky and prone to hang.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
)
def test_stateless_process_group(worker):
    """Run ``worker`` in 4 forked processes and require every one to succeed.

    Each process is started with ``(rank, WORLD_SIZE, port1, port2)``. Workers
    that have not exited before a shared deadline are killed and reported as a
    failure, instead of blocking the test session forever.
    """
    import time
    from multiprocessing import get_context

    # Total seconds granted to all workers before they are treated as hung.
    join_timeout_s = 300.0

    port1 = get_open_port()
    # Keep port1 bound while picking port2, presumably so get_open_port()
    # cannot hand back the same port twice.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", port1))
        port2 = get_open_port()

    WORLD_SIZE = 4
    ctx = get_context("fork")
    processes = [
        ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
        for rank in range(WORLD_SIZE)
    ]
    for p in processes:
        p.start()

    # One shared deadline, so the worst case is join_timeout_s rather than
    # join_timeout_s per process.
    deadline = time.monotonic() + join_timeout_s
    for p in processes:
        p.join(timeout=max(0.0, deadline - time.monotonic()))

    hung = [p for p in processes if p.is_alive()]
    for p in hung:
        p.kill()
        p.join()
    assert not hung, f"{len(hung)} worker process(es) did not finish in time"
    for p in processes:
        # A process that was still running has exitcode None; require an
        # explicit clean exit rather than a merely falsy value.
        assert p.exitcode == 0
    print("All processes finished.")