"official/nlp/bert/model_training_utils.py" did not exist on "53e3adb8ac783c2d370b9681a7ad4aeaacc7ee1f"
test_trt_allreduce.py 8.35 KB
Newer Older
1
2
3
4
5
6
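"""Unit tests for the custom all-reduce kernel exposed by sgl_kernel.

The suite checks numerical correctness against torch.distributed.all_reduce and
compares latency against vLLM's custom all-reduce. Each test spawns one ray
worker per rank (one GPU each); world sizes larger than the number of visible
GPUs are skipped automatically.
"""
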
import ctypes
import logging
import random
import socket
import time
import unittest
from typing import Any, List, Optional

import ray
import sgl_kernel.ops.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

logger = logging.getLogger(__name__)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
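    """Launch one ray task per rank and block until every rank has finished."""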
    # Using ray makes it easier to debug failures than multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
        cls.world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
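        # Exchange IPC handles so every rank can open the allocations of the
        # other ranks in its own address space.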
        dist.all_gather_object(handles, handle, group=group)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
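        """Free this rank's own allocation from a buffer created by create_shared_buffer."""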
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    def test_correctness(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.correctness)

    def test_performance(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.performance)

    def init_custom_allreduce(self, rank, world_size, group):
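        """Allocate the shared buffers used by sgl_kernel's custom all-reduce and
        create the reduce handle for this rank."""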
        buffer_max_size = 8 * 1024 * 1024
        barrier_max_size = 8 * (24 + 2) * 8

        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
        self.tmp_result_buffer_ptrs = self.create_shared_buffer(
            buffer_max_size, group=group
        )
        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
        )

        self.custom_ptr = custom_ops.init_custom_reduce(
            rank,
            world_size,
            self.rank_data,
            self.buffer_ptrs,
            self.tmp_result_buffer_ptrs,
            self.barrier_in_ptrs,
            self.barrier_out_ptrs,
        )

    def custom_allreduce(self, inp, out):
        custom_ops.custom_reduce(self.custom_ptr, inp, out)

    def free_custom_allreduce(self, group):
        self.free_shared_buffer(self.buffer_ptrs, group)
        self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
        self.free_shared_buffer(self.barrier_in_ptrs, group)
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        custom_ops.custom_dispose(self.custom_ptr)

    def init_vllm_allreduce(self, rank, group):
        self.vllm_rank = rank
        self.vllm_max_size = 8 * 1024 * 1024
        self.vllm_meta_ptrs = self.create_shared_buffer(
            vllm_ops.meta_size() + self.vllm_max_size, group=group
        )
        self.vllm_buffer_ptrs = self.create_shared_buffer(
            self.vllm_max_size, group=group
        )
        self.vllm_rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
        )
        self.vllm_ptr = vllm_ops.init_custom_ar(
            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
        )
        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)

    def vllm_allreduce(self, inp, out):
        vllm_ops.all_reduce(
            self.vllm_ptr,
            inp,
            out,
            self.vllm_buffer_ptrs[self.vllm_rank],
            self.vllm_max_size,
        )

    def free_vllm_allreduce(self, group):
        vllm_ops.dispose(self.vllm_ptr)
        self.free_shared_buffer(self.vllm_meta_ptrs, group)
        self.free_shared_buffer(self.vllm_buffer_ptrs, group)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
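        """Bind the worker to its GPU, initialize the default NCCL process group,
        and return an additional gloo group used for host-side object exchange
        and as the reference all-reduce."""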
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
        ranks = list(range(world_size))
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        dist.init_process_group(
            backend="nccl",
            init_method=distributed_init_method,
            rank=rank,
            world_size=world_size,
        )
        group = torch.distributed.new_group(ranks, backend="gloo")
        return group

    # compare result with torch.distributed
    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        test_loop = 10
        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = torch.empty_like(inp1)
                    self.custom_allreduce(inp1, out1)

                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)

        self.free_custom_allreduce(group)

    # compare performance with vllm
    @ray.remote(num_gpus=1, max_calls=1)
    def performance(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_vllm_allreduce(rank, group)
        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        for sz in self.test_sizes:
            inp1 = torch.randint(
                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
            )
            out1 = torch.empty_like(inp1)
            test_loop = 5000
            start = time.time()
            for _ in range(test_loop):
                self.custom_allreduce(inp1, out1)
            elapse_custom = time.time() - start

            start = time.time()
            for _ in range(test_loop):
                self.vllm_allreduce(inp1, out1)
            elapse_vllm = time.time() - start

            if rank == 0:
                logger.warning(
                    f"test_size = {sz}, world_size = {world_size}, "
                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
                )

        self.free_custom_allreduce(group)
        self.free_vllm_allreduce(group)


if __name__ == "__main__":
    unittest.main()