import ctypes
import multiprocessing as mp
import random
import socket
import unittest
from typing import Any, List, Optional

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    test_target: Any,
) -> None:
    procs = []
    distributed_init_port = get_open_port()
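    # Spawn one process per rank; all ranks rendezvous on the same TCP port.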
    for i in range(world_size):
        proc = mp.Process(
            target=test_target,
            args=(world_size, i, distributed_init_port),
        )
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
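        # test_sizes are element counts per tensor; world sizes that exceed the
        # number of visible GPUs are skipped in test_correctness.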
        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
        cls.world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)
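        # Keep the local pointer for this rank and map every peer's buffer into
        # this process via its CUDA IPC handle.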

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    def test_correctness(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self.correctness)
            print(f"custom allreduce tp = {world_size}: OK")

    def init_custom_allreduce(self, rank, world_size, group):
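        # Allocate IPC-shared workspaces for the custom all-reduce kernel:
        # data/result buffers plus in/out barrier flag buffers.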
        buffer_max_size = 8 * 1024 * 1024
        barrier_max_size = 8 * (24 + 2) * 8

        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
            buffer_max_size, group=group
        )
        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.rank_data = torch.empty(
            8 * 1024 * 1024,
            dtype=torch.uint8,
            device="cuda",  # current device, set earlier in init_distributed_env
        )

        self.custom_ptr = custom_ops.init_custom_reduce(
            rank,
            world_size,
            self.rank_data,
            self.buffer_ptrs,
            self.tmp_result_buffer_ptrs,
            self.barrier_in_ptrs,
            self.barrier_out_ptrs,
        )

    def custom_allreduce(self, inp, out):
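        # One custom all-reduce: reduces inp across ranks and writes the result into out.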
        custom_ops.custom_reduce(self.custom_ptr, inp, out)

    def free_custom_allreduce(self, group):
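        # Release this rank's IPC buffers and destroy the custom all-reduce handle.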
        self.free_shared_buffer(self.buffer_ptrs, group)
        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
        self.free_shared_buffer(self.barrier_in_ptrs, group)
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        custom_ops.custom_dispose(self.custom_ptr)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
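        # NCCL requires one GPU per process, so bind each rank to its own device.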
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        ranks = list(range(world_size))
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        dist.init_process_group(
            backend="nccl",
            init_method=distributed_init_method,
            rank=rank,
            world_size=world_size,
        )
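        # correctness() runs its reference all_reduce over this separate gloo
        # group rather than the NCCL group created above.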
        group = torch.distributed.new_group(ranks, backend="gloo")
        return group

    # Compare the custom all-reduce output against torch.distributed.all_reduce.
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        test_loop = 10
        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = torch.empty_like(inp1)
                    self.custom_allreduce(inp1, out1)
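                    # Reference: in-place all_reduce of the same input over the gloo group.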

                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)

        self.free_custom_allreduce(group)


if __name__ == "__main__":
    unittest.main()