test_utils.py 4.75 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import socket

5
import pytest
6
import ray
7
import torch
8

9
import vllm.envs as envs
10
11
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
12
from vllm.utils import (cuda_device_count_stateless, get_open_port,
13
                        update_environment_variables)
14

15
16
from ..utils import multi_gpu_test

17
18

@ray.remote
class _CUDADeviceCountStatelessTestActor:
    """Ray actor that probes ``cuda_device_count_stateless`` remotely.

    The test drives this actor through ``ray.get(actor.<method>.remote())``
    so that ``CUDA_VISIBLE_DEVICES`` is changed inside the actor rather than
    in the pytest process itself.
    """

    def get_count(self):
        """Return the value of ``cuda_device_count_stateless()``."""
        return cuda_device_count_stateless()

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        """Set ``CUDA_VISIBLE_DEVICES`` via ``update_environment_variables``.

        Args:
            cuda_visible_devices: comma-separated device list; ``""`` hides
                every device.
        """
        update_environment_variables(
            {"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

    def get_cuda_visible_devices(self):
        """Return ``CUDA_VISIBLE_DEVICES`` as read through ``vllm.envs``."""
        return envs.CUDA_VISIBLE_DEVICES
30
31
32
33
34


def test_cuda_device_count_stateless():
    """Test that cuda_device_count_stateless changes return value if
    CUDA_VISIBLE_DEVICES is changed."""
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2).remote()
    # The actor was requested with num_gpus=2, so its visible-device list
    # should contain exactly two comma-separated entries.
    visible_devices = ray.get(actor.get_cuda_visible_devices.remote())
    assert len(visible_devices.split(",")) == 2
    assert ray.get(actor.get_count.remote()) == 2
    # Changing the variable after the first query must be reflected in the
    # next count, i.e. the result must not be cached from an earlier call.
    ray.get(actor.set_cuda_visible_devices.remote("0"))
    assert ray.get(actor.get_count.remote()) == 1
    ray.get(actor.set_cuda_visible_devices.remote(""))
    assert ray.get(actor.get_count.remote()) == 0
45
46


47
def cpu_worker(rank, WORLD_SIZE, port1, port2):
    """Broadcast objects over two overlapping stateless process groups.

    ``pg1`` spans all ``WORLD_SIZE`` ranks on ``port1``. ``pg2`` spans only
    ranks 0-2 on ``port2``. Both broadcasts use ``src=2``, so every member
    must end up with rank 2's payload: ``2`` from ``pg1`` and ``3``
    (``rank + 1``) from ``pg2``.

    Args:
        rank: this process's rank in ``pg1``.
        WORLD_SIZE: number of ranks in ``pg1``.
        port1: TCP port used to create ``pg1``.
        port2: TCP port used to create ``pg2``.
    """
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    in_pg2 = rank <= 2
    if in_pg2:
        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                           port=port2,
                                           rank=rank,
                                           world_size=3)
    data = torch.tensor([rank])
    data = pg1.broadcast_obj(data, src=2)
    assert data.item() == 2
    if in_pg2:
        data = torch.tensor([rank + 1])
        data = pg2.broadcast_obj(data, src=2)
        assert data.item() == 3
        pg2.barrier()
    pg1.barrier()
66
67


68
def gpu_worker(rank, WORLD_SIZE, port1, port2):
    """All-reduce a GPU tensor over two overlapping NCCL communicators.

    Each rank starts with a tensor holding its own rank id and all-reduces
    it over ``pg1`` (all ranks). Ranks 0-2 then all-reduce the result again
    over ``pg2`` (3 ranks). The assertions expect a sum reduction: after
    ``pg1`` every rank holds ``0 + 1 + ... + (WORLD_SIZE - 1)``, and ``pg2``
    members hold three times that.

    Args:
        rank: this process's rank in ``pg1``; also used as the CUDA device.
        WORLD_SIZE: number of ranks in ``pg1``.
        port1: TCP port used to create ``pg1``.
        port2: TCP port used to create ``pg2``.
    """
    torch.cuda.set_device(rank)
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    pynccl1 = PyNcclCommunicator(pg1, device=rank)
    in_pg2 = rank <= 2
    if in_pg2:
        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                           port=port2,
                                           rank=rank,
                                           world_size=3)
        pynccl2 = PyNcclCommunicator(pg2, device=rank)
    data = torch.tensor([rank]).cuda()
    pynccl1.all_reduce(data)
    pg1.barrier()
    torch.cuda.synchronize()
    if in_pg2:
        pynccl2.all_reduce(data)
        pg2.barrier()
        torch.cuda.synchronize()
    item = data[0].item()
    print(f"rank: {rank}, item: {item}")
    # Sum of all rank ids produced by the pg1 all-reduce (6 for 4 ranks).
    pg1_sum = WORLD_SIZE * (WORLD_SIZE - 1) // 2
    if in_pg2:
        # pg2 has 3 members that all hold pg1_sum before reducing again.
        assert item == pg1_sum * 3
    else:
        assert item == pg1_sum


97
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
    """Check that a Python object broadcast from rank 2 reaches every rank.

    Args:
        rank: this process's rank in the group.
        WORLD_SIZE: number of ranks in the group.
        port1: TCP port used to create the group.
        port2: unused; accepted so every worker shares the signature that
            ``test_stateless_process_group`` calls with.
    """
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    if rank == 2:
        pg1.broadcast_obj("secret", src=2)
    else:
        # Non-source ranks pass None and must receive the source's object.
        obj = pg1.broadcast_obj(None, src=2)
        assert obj == "secret"
    pg1.barrier()


110
def allgather_worker(rank, WORLD_SIZE, port1, port2):
    """Check that ``all_gather_obj`` returns every rank's id in rank order.

    Args:
        rank: this process's rank in the group; also the gathered value.
        WORLD_SIZE: number of ranks in the group.
        port1: TCP port used to create the group.
        port2: unused; accepted so every worker shares the signature that
            ``test_stateless_process_group`` calls with.
    """
    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
                                       port=port1,
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    data = pg1.all_gather_obj(rank)
    assert data == list(range(WORLD_SIZE))
    pg1.barrier()


120
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize(
    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
def test_stateless_process_group(worker):
    """Run ``worker`` in ``WORLD_SIZE`` forked processes; all must exit 0.

    Each process receives ``(rank, WORLD_SIZE, port1, port2)``. If any
    process is still running when the overall deadline expires, all
    stragglers are terminated and the test fails. Without the deadline, a
    deadlocked worker would hang the test run.
    """
    import time
    from multiprocessing import get_context

    port1 = get_open_port()
    # Keep port1 bound while choosing port2, presumably so get_open_port()
    # cannot return port1 a second time.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", port1))
        port2 = get_open_port()

    WORLD_SIZE = 4
    timeout_s = 300.0
    ctx = get_context("fork")
    processes = [
        ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
        for rank in range(WORLD_SIZE)
    ]
    for p in processes:
        p.start()

    # One shared deadline bounds the total wait across all joins.
    deadline = time.monotonic() + timeout_s
    for p in processes:
        p.join(timeout=max(0.0, deadline - time.monotonic()))

    hung_ranks = [rank for rank, p in enumerate(processes) if p.is_alive()]
    for rank in hung_ranks:
        processes[rank].terminate()
        processes[rank].join()
    # A timed-out process has exitcode None, which `not p.exitcode` would
    # silently accept, so hangs are reported explicitly.
    assert not hung_ranks, (
        f"ranks {hung_ranks} did not finish within {timeout_s}s")
    for rank, p in enumerate(processes):
        assert p.exitcode == 0, f"rank {rank} exited with code {p.exitcode}"
    print("All processes finished.")