# test_shm_broadcast.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
import multiprocessing
import random
import time

8
import numpy as np
9
10
import torch.distributed as dist

11
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
12
from vllm.distributed.utils import StatelessProcessGroup
13
14
from vllm.utils import update_environment_variables
from vllm.utils.network_utils import get_open_port
15
16


17
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Return ``n`` reproducible 1-D integer arrays of random length.

    Array lengths are drawn from ``[1, 10_000)`` and values from ``[1, 100)``.
    The same ``(n, seed)`` pair always produces the same arrays, so a writer
    and its readers can regenerate identical payloads independently.

    Args:
        n: Number of arrays to generate.
        seed: Seed for the random stream.

    Returns:
        A list of ``n`` NumPy integer arrays.
    """
    # A private RandomState yields the same stream as the legacy
    # ``np.random.seed(seed)`` + ``np.random.randint`` pair, but leaves the
    # process-global NumPy RNG untouched.
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with int64, each array will have 40kb
    return [rng.randint(1, 100, size) for size in sizes]


25
26
27
28
29
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require all to succeed.

    Each child receives a single argument: a dict of torch.distributed-style
    environment variables (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT``) describing its
    place in a single-node job. ``fn`` is responsible for applying them.

    Args:
        fn: Callable taking the environment dict; used as the process target.
        world_size: Number of processes (ranks) to launch.

    Raises:
        AssertionError: If any child exits with a non-zero exit code.
    """
    # Pick one free rendezvous port for the whole run instead of a fixed
    # number, so a busy port or a concurrent test run does not break the
    # rendezvous. Every rank must see the same value.
    master_port = str(get_open_port())

    processes = []
    for rank in range(world_size):
        env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": master_port,
        }
        p = multiprocessing.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    # Wait for every rank before checking results so no child is left behind.
    for p in processes:
        p.join()

    for rank, p in enumerate(processes):
        assert p.exitcode == 0, f"rank {rank} exited with code {p.exitcode}"


def worker_fn_wrapper(fn):
    """Decorate a zero-argument worker so it can be a process target.

    The returned function takes the environment dict, applies it to the
    current process, initialises the default torch.distributed process group
    with the ``gloo`` backend, and then calls ``fn()``.

    Args:
        fn: Zero-argument callable to run once the process group is ready.

    Returns:
        A one-argument callable ``wrapped_fn(env)``.
    """
    from functools import wraps

    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    #
    # `wraps` copies fn's name/qualname/doc onto the wrapper. When this is
    # used as a decorator on a module-level function, the wrapper is then
    # what that module-level name refers to, so it can be pickled by
    # reference for non-fork start methods.
    @wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    """Check MessageQueue broadcasting over two kinds of process group.

    For the default torch.distributed group and for a StatelessProcessGroup
    built on a rank-0-chosen endpoint, the writer rank broadcasts a seeded
    sequence of arrays through a MessageQueue. Every other rank regenerates
    the same sequence locally and asserts that each received array matches.
    """
    rank = dist.get_rank()

    # Rank 0 chooses the endpoint for the stateless group and shares it with
    # everyone through the already-initialised default group.
    if rank == 0:
        endpoint = ["127.0.0.1", get_open_port()]
    else:
        endpoint = [None, None]
    dist.broadcast_object_list(endpoint, src=0)
    ip, port = endpoint  # type: ignore

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())

    writer_rank = 2
    # test broadcasting with about 400MB of data
    N = 10_000

    # Each entry: the group under test, its barrier, and the label printed
    # on success.
    groups = [
        (dist.group.WORLD, dist.barrier, "torch distributed"),
        (stateless_pg, stateless_pg.barrier, "StatelessProcessGroup"),
    ]

    for pg, barrier, label in groups:
        # NOTE(review): 40 * 1024 and 2 look like the queue's chunk size in
        # bytes and chunk count -- confirm against MessageQueue.
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank
        )

        # The writer draws the seed; everyone else receives it. This always
        # goes through the default group, whichever group is being tested.
        seed_box = [random.randint(0, 1000)] if rank == writer_rank else [None]
        dist.broadcast_object_list(seed_box, src=writer_rank)
        seed = seed_box[0]  # type: ignore

        barrier()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")

        # Every rank builds the same arrays from the shared seed: the writer
        # sends them, the readers use them as the expected values.
        arrs = get_arrays(N, seed)
        if rank == writer_rank:
            for x in arrs:
                broadcaster.broadcast_object(x)
                # small random delay between messages
                time.sleep(random.random() / 1000)
        else:
            for x in arrs:
                y = broadcaster.broadcast_object(None)
                assert np.array_equal(x, y)
                time.sleep(random.random() / 1000)

        barrier()
        print(f"{label} passed the test! Rank {rank}")
114
115
116
117


def test_shm_broadcast():
    """Run the shared-memory broadcast worker across four local processes."""
    # worker_fn uses rank 2 as the writer, so more than two ranks are needed.
    distributed_run(fn=worker_fn, world_size=4)