# test_same_node.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import os

6
import torch
7
import torch.distributed as dist
8

9
from vllm.distributed.parallel_state import in_the_same_node_as
10
from vllm.distributed.utils import StatelessProcessGroup
11
from vllm.utils.network_utils import get_ip, get_open_port
12

13
14
15
16
17
18
19
20
21
22
23
24

def _run_test(pg):
    """Assert that ``in_the_same_node_as`` matches the expected host layout.

    Args:
        pg: The process group to probe. In this script it is either
            ``dist.group.WORLD`` or a ``StatelessProcessGroup``. It is passed
            unchanged to ``in_the_same_node_as`` with ``source_rank=0``.

    Raises:
        AssertionError: if ``all(...)`` over the per-rank result disagrees
            with ``VLLM_TEST_SAME_HOST``. That variable defaults to ``"1"``,
            meaning all ranks are expected on the same host.
    """
    # Presumably one truthy flag per rank saying whether it shares rank 0's
    # node. The group only counts as "same node" if every flag is set.
    per_rank = in_the_same_node_as(pg, source_rank=0)
    observed = all(per_rank)

    wanted = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
    assert observed == wanted, f"Expected {wanted}, got {observed}"

    # Report which flavour of process group just passed.
    message = (
        "Same node test passed! when using torch distributed!"
        if pg == dist.group.WORLD
        else "Same node test passed! when using StatelessProcessGroup!"
    )
    print(message)


25
26
if __name__ == "__main__":
    # No init_method is given, so torch.distributed uses its default env://
    # rendezvous (MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE). The script is
    # presumably started by torchrun or a similar launcher.
    dist.init_process_group(backend="gloo")
    try:
        rank = dist.get_rank()
        # Rank 0 chooses the address for the StatelessProcessGroup and shares
        # it over the gloo group, so every rank connects to the same endpoint.
        if rank == 0:
            port = get_open_port()
            ip = get_ip()
            dist.broadcast_object_list([ip, port], src=0)
        else:
            recv = [None, None]
            # On non-source ranks the list is filled in place by the broadcast.
            dist.broadcast_object_list(recv, src=0)
            ip, port = recv

        stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())

        # Both settings are loop-invariant, so resolve them once up front.
        with_default_device = (
            os.environ.get("VLLM_TEST_WITH_DEFAULT_DEVICE_SET", "0") == "1"
        )
        default_devices = ["cpu"]
        if with_default_device and torch.cuda.is_available():
            default_devices.append("cuda")

        # Run the same check through both kinds of process group.
        for pg in [dist.group.WORLD, stateless_pg]:
            if with_default_device:
                # Repeat the check once per available default device.
                for device in default_devices:
                    torch.set_default_device(device)
                    _run_test(pg)
            else:
                _run_test(pg)
    finally:
        # Release the gloo default group even when a check fails, instead of
        # leaving teardown to interpreter shutdown.
        dist.destroy_process_group()