test_trt_allreduce.py 6.62 KB
Newer Older
1
import ctypes
Yi Zhang's avatar
Yi Zhang committed
2
import multiprocessing as mp
3
4
5
import random
import socket
import unittest
6
from typing import Any, List, Optional
7

8
import sgl_kernel.allreduce as custom_ops
9
10
11
12
13
14
15
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: check ``custom_ops.custom_reduce`` against NCCL all_reduce.

    Started once per rank by ``multi_process_parallel``. Initializes an NCCL
    process group over TCP on localhost, allocates the IPC-shared buffers that
    ``custom_ops.init_custom_reduce`` takes, then for every size in
    ``test_sizes`` and every dtype in (float32, float16, bfloat16) verifies
    that the custom reduce output equals ``dist.all_reduce`` on the same input.

    Args:
        world_size: Total number of ranks.
        rank: This process's rank; also used as its CUDA device index.
        distributed_init_port: Localhost TCP port used by ``init_process_group``.
        test_sizes: Iterable of element counts to test.

    Raises:
        AssertionError: If the custom reduce output differs from NCCL's.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{distributed_init_port}",
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    buffer_max_size = 8 * 1024 * 1024
    # NOTE(review): constant kept verbatim; the 8 * (24 + 2) * 8 breakdown
    # presumably mirrors the kernel's barrier layout -- confirm there.
    barrier_max_size = 8 * (24 + 2) * 8
    dtypes = (torch.float32, torch.float16, torch.bfloat16)
    test_loop = 10

    buffer_ptrs = None
    tmp_result_buffer_ptrs = None
    barrier_in_ptrs = None
    barrier_out_ptrs = None
    custom_ptr = None

    try:
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
            buffer_max_size, group=group
        )
        tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
            buffer_max_size, group=group
        )
        barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
            barrier_max_size, group=group
        )
        barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
            barrier_max_size, group=group
        )

        # Device workspace passed to init_custom_reduce (its use is kernel-defined).
        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)

        custom_ptr = custom_ops.init_custom_reduce(
            rank,
            world_size,
            rank_data,
            buffer_ptrs,
            tmp_result_buffer_ptrs,
            barrier_in_ptrs,
            barrier_out_ptrs,
        )

        for sz in test_sizes:
            for dtype in dtypes:
                for _ in range(test_loop):
                    # Integers in [1, 15]: sums over up to 8 ranks (<= 120) are
                    # exactly representable in fp16/bf16, so exact comparison
                    # against NCCL is meaningful.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)

                    custom_ops.custom_reduce(custom_ptr, inp1, out1)
                    dist.all_reduce(inp1_ref, group=group)

                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        try:
            # All ranks sync before tearing down shared state (presumably so no
            # peer frees memory another rank still reads through IPC).
            dist.barrier(group=group)
            if custom_ptr is not None:
                custom_ops.custom_dispose(custom_ptr)
            # free_shared_buffer ends in a barrier, so every rank must free the
            # buffers in the same order.
            for ptrs in (
                buffer_ptrs,
                tmp_result_buffer_ptrs,
                barrier_in_ptrs,
                barrier_out_ptrs,
            ):
                if ptrs:
                    TestCustomAllReduce.free_shared_buffer(ptrs, group)
        finally:
            # Always destroy the process group, even if cleanup above raised.
            dist.destroy_process_group(group=group)


93
94
95
96
97
98
99
def get_open_port() -> int:
    """Return a TCP port on the loopback interface that the OS reports as free.

    Binds to port 0 on ``127.0.0.1`` so the OS picks an unused port, and falls
    back to the IPv6 loopback ``::1`` when IPv4 binding fails. The socket is
    closed before returning, so the port is only likely (not guaranteed) to
    still be free when the caller binds it.

    Returns:
        The port number chosen by the OS.

    Raises:
        OSError: If binding fails on both the IPv4 and IPv6 loopback.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            return sock.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
            sock.bind(("::1", 0))
            return sock.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Run ``test_target`` in ``world_size`` spawned processes and check they succeed.

    Each process ``i`` is called as
    ``test_target(world_size, i, distributed_init_port, *target_args)``, where
    ``distributed_init_port`` is a single free localhost port shared by all of
    them.

    Args:
        world_size: Number of worker processes (one per rank).
        test_target: Picklable top-level callable run in each process.
        target_args: Extra positional arguments appended after the rank info.

    Raises:
        AssertionError: If any worker exits with a non-zero exit code.
    """
    # "spawn" is required for CUDA in child processes.
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    try:
        for i, proc in enumerate(procs):
            proc.join()
            if proc.exitcode != 0:
                raise AssertionError(
                    f"Process {i} failed with exit code {proc.exitcode}"
                )
    finally:
        # After a failure the remaining workers may be blocked waiting on the
        # dead peer; terminate them so the parent does not hang joining them.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
                proc.join()
122
123
124


class TestCustomAllReduce(unittest.TestCase):
    """Multi-GPU correctness test for the sgl_kernel custom all-reduce.

    Also provides the static helpers that allocate and free the CUDA
    IPC-shared buffers used by ``_run_correctness_worker``.
    """

    # Element counts exercised for each dtype by every worker.
    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
    # Ranks-per-run to try; sizes exceeding the visible GPU count are skipped.
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """Allocate a device buffer on this rank and map every peer's buffer via CUDA IPC.

        Collective over ``group``: every rank must call it with the same size.
        Each rank ``cudaMalloc``s its own buffer, exports an IPC handle,
        all-gathers the raw handle bytes, and opens each peer's handle.

        Args:
            size_in_bytes: Size of this rank's allocation.
            group: Process group to exchange handles over; defaults to WORLD.

        Returns:
            Device pointers (as ints) indexed by rank within ``group``; entry
            ``rank`` is this rank's own allocation.

        Raises:
            Exception: Re-raised if opening a peer's IPC handle fails.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)

        if group is None:
            group = dist.group.WORLD

        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        # Ship the opaque handle struct as raw bytes inside a device tensor so
        # it can be exchanged with dist.all_gather.
        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)

        # Rebuild a handle of the same ctypes type from each rank's bytes.
        handle_type = type(handle)
        handles = [
            handle_type.from_buffer_copy(bytes(tensor.cpu().tolist()))
            for tensor in gathered_tensors
        ]

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # Our own allocation is used directly, not through IPC.
                pointers.append(pointer.value)
                continue
            try:
                opened_ptr = lib.cudaIpcOpenMemHandle(h)
            except Exception as e:
                print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                raise
            pointers.append(opened_ptr.value)

        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        """Free this rank's own allocation made by ``create_shared_buffer``.

        Collective over ``group`` (it ends with a barrier), so every rank must
        call it. Only ``pointers[rank]`` is passed to ``cudaFree``.
        NOTE(review): the IPC mappings of peer buffers are not closed here.
        """
        if group is None:
            group = dist.group.WORLD

        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
        """Run the custom all-reduce correctness worker for each world size.

        World sizes needing more GPUs than are visible are recorded as skipped
        subtests instead of only being printed. Otherwise a machine without
        enough GPUs would report a pass having run nothing. Each world size
        runs in its own subtest, so one failure does not hide the others.
        """
        available_gpus = torch.cuda.device_count()
        for world_size in self.world_sizes:
            with self.subTest(world_size=world_size):
                if world_size > available_gpus:
                    self.skipTest(
                        f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                    )

                print(f"Running test for world_size={world_size}")
                multi_process_parallel(
                    world_size, _run_correctness_worker, target_args=(self.test_sizes,)
                )
                print(f"custom allreduce tp = {world_size}: OK")
195
196
197
198


# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
    unittest.main()