test_mscclpp.py 7.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""For Now, MSCCL is only supported on TP16 and TP8 case

if [[ $RANK -eq 0  ]]; then
    ray start --block --head --port=6379 &
    python3 test_mscclpp.py;
else
    ray start --block --address=${MASTER_ADDR}:6379;
fi
"""

import itertools
import os
import random
import socket
import unittest
from contextlib import contextmanager, nullcontext
from typing import Any, List, Optional, Union

import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
    CustomAllreduce,
)
from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
    set_custom_all_reduce,
    set_mscclpp_all_reduce,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.test.test_utils import CustomTestCase


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    master_addr: str,
    cls: Any,
    test_target: Any,
) -> None:

    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers

    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(
            test_target.remote(
                cls, world_size, master_addr, rank, distributed_init_port
            )
        )
    ray.get(refs)

    ray.shutdown()


class TestMSCCLAllReduce(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        # 1KB to 1MB
        cls.test_sizes = [512, 4096, 32768, 262144, 524288]
        cls.world_sizes = [8]
        TEST_TP16 = int(os.getenv("SGL_MSCCLPP_TEST_TP16", "0"))
        if TEST_TP16:
            cls.world_sizes = [16]
        cls.test_loop = 10

    def test_graph_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.graph_allreduce
            )

    def test_eager_allreduce(self):
        TEST_MASTER_ADDR = os.getenv("SGL_MSCCLPP_TEST_MASTER_ADDR", "localhost")
        for world_size in self.world_sizes:
            if world_size not in [8, 16]:
                continue
            multi_process_parallel(
                world_size, TEST_MASTER_ADDR, self, self.eager_allreduce
            )

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank % torch.cuda.device_count(),
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup.
        # this is needed because device communicators might be created lazily
        # (e.g. NCCL). This will ensure that the communicator is initialized
        # before any communication happens, so that this group can be used for
        # graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            1,
                            16,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
                    torch.testing.assert_close(out1, inp1)
                    torch.testing.assert_close(out2, inp2)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, master_addr, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://{master_addr}:{distributed_init_port}"
        set_mscclpp_all_reduce(True)
        set_custom_all_reduce(False)
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        for sz in self.test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(self.test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)


if __name__ == "__main__":
    unittest.main()