# SPDX-License-Identifier: Apache-2.0

import functools
import multiprocessing
import random
import time

import numpy as np
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_open_port, update_environment_variables


def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Generate ``n`` reproducible random integer arrays of random length.

    Each array has between 1 and 9_999 elements, with values drawn from
    ``[1, 100)``. The same ``(n, seed)`` always yields the same arrays, so a
    reader can regenerate locally exactly what the writer broadcast.

    A private ``RandomState`` is used instead of ``np.random.seed`` so the
    caller's global NumPy RNG is left untouched. The generated values are
    identical to seeding the legacy global generator with ``seed``.
    """
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements;
    # with int64 elements (the default dtype on most Linux platforms),
    # each array will be about 40kB
    return [rng.randint(1, 100, size) for size in sizes]


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` local processes and require all to pass.

    Each child process receives one positional argument: a dict of
    torch.distributed-style environment variables (RANK, LOCAL_RANK,
    WORLD_SIZE, LOCAL_WORLD_SIZE, MASTER_ADDR, MASTER_PORT). ``fn`` is
    responsible for applying it, as ``worker_fn_wrapper`` does.

    Raises:
        AssertionError: if any process exits with a non-zero exit code; the
            message maps each failing rank to its exit code.
    """
    # Pick one free port for the whole group instead of a fixed one, so the
    # test does not collide with another run (or service) using that port.
    master_port = str(get_open_port())
    processes = []
    for rank in range(world_size):
        env = {
            'RANK': str(rank),
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': master_port,
        }
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    failed = {
        rank: p.exitcode
        for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failed, (
        f"worker processes exited with errors (rank: exitcode): {failed}")


def worker_fn_wrapper(fn):
    """Adapt a zero-argument worker ``fn`` into a ``multiprocessing`` target.

    The returned callable takes the env dict built by ``distributed_run``,
    applies it with ``update_environment_variables``, joins the default
    process group with the gloo backend, and runs ``fn``. The default group
    is destroyed afterwards, even if ``fn`` raises.
    """

    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    @functools.wraps(fn)  # keep the worker's name/docstring for tracebacks
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        try:
            fn()
        finally:
            # release the default group's resources before the process exits
            dist.destroy_process_group()

    return wrapped_fn


def _barrier(pg):
    """Synchronize all ranks of ``pg``.

    The torch.distributed default group needs ``dist.barrier()``; the
    ``StatelessProcessGroup`` used here provides its own ``barrier()``.
    """
    if pg == dist.group.WORLD:
        dist.barrier()
    else:
        pg.barrier()


@worker_fn_wrapper
def worker_fn():
    """Broadcast random arrays through a shared-memory ``MessageQueue``.

    Runs once per rank. The check is repeated for two process groups: the
    torch.distributed default group and a ``StatelessProcessGroup``. Rank
    ``writer_rank`` broadcasts every array. Every other rank receives them
    and compares each against arrays it regenerates locally from the same
    seed.
    """
    rank = dist.get_rank()

    # rank 0 chooses the address/port for the StatelessProcessGroup and
    # shares it with the other ranks over torch.distributed
    if rank == 0:
        port = get_open_port()
        ip = '127.0.0.1'
        dist.broadcast_object_list([ip, port], src=0)
    else:
        recv = [None, None]
        dist.broadcast_object_list(recv, src=0)
        ip, port = recv  # type: ignore

    stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                dist.get_world_size())

    for pg in [dist.group.WORLD, stateless_pg]:

        writer_rank = 2
        # NOTE(review): presumably (pg, max_chunk_bytes, max_chunks,
        # writer_rank); 40 KiB is below the largest arrays (~80kB as
        # int64), which likely exercises the large-message path — confirm
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank)
        if rank == writer_rank:
            seed = random.randint(0, 1000)
            dist.broadcast_object_list([seed], writer_rank)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, writer_rank)
            seed = recv[0]  # type: ignore

        _barrier(pg)

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")
        # test broadcasting with about 400MB of data
        N = 10_000
        if rank == writer_rank:
            arrs = get_arrays(N, seed)
            for x in arrs:
                broadcaster.broadcast_object(x)
                time.sleep(random.random() / 1000)
        else:
            arrs = get_arrays(N, seed)
            for x in arrs:
                y = broadcaster.broadcast_object(None)
                assert np.array_equal(x, y)
                time.sleep(random.random() / 1000)

        _barrier(pg)
        if pg == dist.group.WORLD:
            print(f"torch distributed passed the test! Rank {rank}")
        else:
            print(f"StatelessProcessGroup passed the test! Rank {rank}")


def test_shm_broadcast():
    """Run ``worker_fn`` on four local ranks; any failing rank fails the test."""
    world_size = 4
    distributed_run(fn=worker_fn, world_size=world_size)