import ctypes
import multiprocessing as mp
import random
import socket
import unittest
from typing import Any, List, Optional

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
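    """Per-rank worker: initialize NCCL, set up IPC-shared buffers for the
    custom allreduce, and compare its output against dist.all_reduce for every
    test size and dtype.
    """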
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    # Initialize these to None so the finally block can clean up safely even
    # if setup fails partway through.
    custom_ptr = None
    buffer_ptrs = None
    meta_ptrs = None
    try:
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )

        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)

        # Create the custom allreduce context from the shared metadata buffers
        # and register the IPC-shared data buffers with it.
        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
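                    # Small random integers keep the reduction exact, so the
                    # custom kernel's output can be compared directly with the
                    # NCCL reference below.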
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)

                    custom_ops.all_reduce(
                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                    )

                    dist.all_reduce(inp1_ref, group=group)

                    torch.testing.assert_close(out1, inp1_ref)

    finally:
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)

        dist.destroy_process_group(group=group)


def get_open_port() -> int:
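    """Return a free TCP port, preferring IPv4 loopback and falling back to IPv6."""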
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
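    """Spawn one process per rank running `test_target` and assert that every
    process exits with code 0.
    """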
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestCustomAllReduce(unittest.TestCase):
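    """Correctness tests for the sgl_kernel custom allreduce.

    Each test spawns one process per rank, reduces random tensors with the
    custom allreduce, and checks the results against NCCL's allreduce.
    """
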
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
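        """Allocate size_in_bytes of device memory on this rank, exchange CUDA
        IPC handles via all_gather, and return one pointer per rank: the local
        allocation for this rank and opened IPC mappings for every other rank.
        """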
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        if group is None:
            group = dist.group.WORLD
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)

        handles = []
        handle_type = type(handle)
        for tensor in gathered_tensors:
            bytes_list = tensor.cpu().tolist()
            bytes_data = bytes(bytes_list)
            handle_obj = handle_type()
            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
            handles.append(handle_obj)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)
            else:
                try:
                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
                    pointers.append(opened_ptr.value)
                except Exception as e:
                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                    raise

        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
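        """Free this rank's own allocation returned by create_shared_buffer,
        then barrier so all ranks finish cleanup together.
        """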
        if group is None:
            group = dist.group.WORLD
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
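        """Run the correctness worker for every configured world size that fits
        on the available GPUs; larger configurations are skipped.
        """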
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()