test_utils.py 3.54 KB
Newer Older
1
import pytest
2
import ray
3
4
import torch
import torch.distributed as dist
5

6
import vllm.envs as envs
7
from vllm.distributed.utils import stateless_init_process_group
8
from vllm.utils import (cuda_device_count_stateless,
9
                        update_environment_variables)
10

11
12
from ..utils import multi_gpu_test

13
14

@ray.remote
class _CUDADeviceCountStatelessTestActor:
    """Ray actor used to query ``cuda_device_count_stateless`` from a worker
    whose ``CUDA_VISIBLE_DEVICES`` value can be changed between calls."""

    def get_count(self):
        """Return what ``cuda_device_count_stateless`` reports in this actor."""
        device_count = cuda_device_count_stateless()
        return device_count

    def set_cuda_visible_devices(self, cuda_visible_devices: str):
        """Set ``CUDA_VISIBLE_DEVICES`` for this actor via
        ``update_environment_variables``."""
        new_env = {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
        update_environment_variables(new_env)

    def get_cuda_visible_devices(self):
        """Return ``CUDA_VISIBLE_DEVICES`` as read through ``vllm.envs``."""
        visible = envs.CUDA_VISIBLE_DEVICES
        return visible
26
27
28
29
30


def test_cuda_device_count_stateless():
    """Test that cuda_device_count_stateless changes return value if
    CUDA_VISIBLE_DEVICES is changed."""
    # The actor is scheduled with two GPUs; the test expects Ray to expose
    # exactly those two through a comma-separated CUDA_VISIBLE_DEVICES.
    actor = _CUDADeviceCountStatelessTestActor.options(  # type: ignore
        num_gpus=2).remote()
    visible = ray.get(actor.get_cuda_visible_devices.remote())
    assert len(visible.split(",")) == 2, visible
    assert ray.get(actor.get_count.remote()) == 2

    # Narrowing the variable inside the same process must be reflected by
    # the next call, i.e. the count is not cached.
    ray.get(actor.set_cuda_visible_devices.remote("0"))
    assert ray.get(actor.get_count.remote()) == 1

    # An empty value hides every device.
    ray.get(actor.set_cuda_visible_devices.remote(""))
    assert ray.get(actor.get_count.remote()) == 0
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104


def _check_two_group_allreduce(rank, world_size, first_port, backend,
                               device):
    """All-reduce ``rank`` over two overlapping groups and assert the result.

    Every rank joins a group of size ``world_size`` rendezvousing on
    ``first_port``. All ranks except the last one also join a subgroup of
    size ``world_size - 1`` on ``first_port + 1``. Summing ranks over the
    full group yields ``S = world_size * (world_size - 1) // 2`` everywhere;
    the subgroup then sums ``S`` over its ``world_size - 1`` members, while
    the last rank keeps ``S``.

    Args:
        rank: Rank of this process in the full group. The subgroup is made
            of the lowest ranks, so this is also the subgroup rank.
        world_size: Number of processes in the full group (expected >= 2).
        first_port: Port on 127.0.0.1 for the full group; the subgroup uses
            the next port.
        backend: Backend name passed to ``stateless_init_process_group``.
        device: Device on which the reduced tensor is allocated.
    """
    sub_size = world_size - 1
    in_subgroup = rank < sub_size

    pg_full = stateless_init_process_group(
        init_method=f"tcp://127.0.0.1:{first_port}",
        rank=rank,
        world_size=world_size,
        backend=backend)
    pg_sub = None
    if in_subgroup:
        pg_sub = stateless_init_process_group(
            init_method=f"tcp://127.0.0.1:{first_port + 1}",
            rank=rank,
            world_size=sub_size,
            backend=backend)

    data = torch.tensor([rank], device=device)
    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg_full)
    if pg_sub is not None:
        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg_sub)
    item = data[0].item()
    print(f"rank: {rank}, item: {item}")

    full_sum = world_size * (world_size - 1) // 2
    expected = full_sum * sub_size if in_subgroup else full_sum
    assert item == expected, f"rank {rank}: expected {expected}, got {item}"


def cpu_worker(rank, WORLD_SIZE):
    """Check nested all-reduces over gloo process groups on CPU tensors."""
    _check_two_group_allreduce(rank,
                               WORLD_SIZE,
                               first_port=29500,
                               backend="gloo",
                               device="cpu")


def gpu_worker(rank, WORLD_SIZE):
    """Check nested all-reduces over NCCL process groups, one GPU per rank."""
    # Bind this process to its own GPU before any NCCL group is created so
    # that all NCCL work in this process targets that device.
    torch.cuda.set_device(rank)
    _check_two_group_allreduce(rank,
                               WORLD_SIZE,
                               first_port=29502,
                               backend="nccl",
                               device="cuda")


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
def test_stateless_init_process_group(worker):
    """Run ``worker`` in WORLD_SIZE processes and require every one to exit 0.

    Each worker asserts its own all-reduce result, so a failed check shows up
    here as a non-zero exit code for that rank.
    """
    WORLD_SIZE = 4
    from multiprocessing import get_context

    # NOTE(review): "fork" is only safe for gpu_worker if CUDA has not been
    # initialized in this pytest process beforehand — confirm in CI.
    ctx = get_context("fork")
    processes = [
        ctx.Process(target=worker, args=(rank, WORLD_SIZE))
        for rank in range(WORLD_SIZE)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()

    # Compare against 0 explicitly: ``not p.exitcode`` would also accept
    # None (process still running).
    failed = {
        rank: p.exitcode
        for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failed, f"worker processes failed (rank -> exitcode): {failed}"
    print("All processes finished.")