test_mscclpp.py 4.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist


class MscclContextSelection(IntEnum):
    """Selector for the context variant built by ``mscclpp_init_context``.

    The worker passes the member to the kernel as a plain ``int``. The names
    suggest one-shot single-node vs. two-node variants — confirm the exact
    meaning against the sgl_kernel implementation.
    """

    # The variant exercised by the correctness worker in this file.
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: check ``mscclpp_allreduce`` against NCCL ``all_reduce``.

    Runs in its own spawned process, one per rank. It joins an NCCL process
    group via ``tcp://localhost:<distributed_init_port>`` and receives the
    MSCCL++ unique id from rank 0 over a Gloo (CPU) group. It then builds a
    ``MSCCL1SHOT1NODELL`` context. For every size in ``test_sizes`` and each
    of float32/float16/bfloat16, it compares ``mscclpp_allreduce`` output with
    ``dist.all_reduce`` over 10 random inputs.

    Args:
        world_size: Total number of ranks.
        rank: This process's rank in ``[0, world_size)``.
        distributed_init_port: Local TCP port used for the rendezvous.
        test_sizes: Element counts to test. A (size, dtype) pair whose byte
            size exceeds ``MAX_BYTES`` (1 MiB) is skipped.

    Raises:
        AssertionError: if an mscclpp result differs from the NCCL reference.
    """
    # One rank per visible GPU; this count is also passed to the context below.
    nranks_per_node = torch.cuda.device_count()
    device = torch.device(f"cuda:{rank % nranks_per_node}")
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{distributed_init_port}",
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # Gloo group so the (non-tensor) unique id can be broadcast on the CPU.
    cpu_group = dist.new_group(list(range(world_size)), backend="gloo")

    # Rank 0 generates the id; every rank must pass the same one to the context.
    unique_id = [custom_ops.mscclpp_generate_unique_id() if rank == 0 else None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]

    # Node index of each rank, with consecutive ranks packed onto the same node.
    # Uses the same ranks-per-node value handed to mscclpp_init_context.
    rank_to_node = [r // nranks_per_node for r in range(world_size)]
    # NOTE(review): every entry is *this* rank's ``rank % 8``, not a per-peer
    # value. Kept unchanged; presumably the kernel uses it to choose the local
    # IB device (8 per node assumed) — confirm against mscclpp_init_context.
    rank_to_ib = [rank % 8] * world_size

    MAX_BYTES = 2**20
    # Sizes below are bfloat16 element counts, not bytes.
    scratch = torch.empty(MAX_BYTES * 8, dtype=torch.bfloat16, device=device)
    put_buffer = torch.empty(MAX_BYTES, dtype=torch.bfloat16, device=device)

    print(f"[{rank}] start mscclpp_context init")
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        int(MscclContextSelection.MSCCL1SHOT1NODELL),
    )

    test_loop = 10
    try:
        for sz in test_sizes:
            for dtype in (torch.float32, torch.float16, torch.bfloat16):
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
        # Synchronise only on success. After a failure the peers may still be
        # inside a different collective, so a barrier here could block forever
        # instead of letting this rank exit and report its error.
        dist.barrier(group=group)
    finally:
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    """Return a loopback TCP port number that was free at the time of the call.

    Binds to port 0 on ``127.0.0.1`` and lets the OS choose. If that raises
    ``OSError``, it retries on the IPv6 loopback ``::1``. The probe socket is
    closed before returning, so the port is not reserved afterwards.
    """

    def _probe(family, host):
        # Port 0 asks the OS to assign any available port.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind((host, 0))
            return sock.getsockname()[1]

    try:
        return _probe(socket.AF_INET, "127.0.0.1")
    except OSError:
        return _probe(socket.AF_INET6, "::1")


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Run ``test_target`` in ``world_size`` spawned processes and require success.

    Worker ``i`` is called as
    ``test_target(world_size, i, distributed_init_port, *target_args)``.
    ``distributed_init_port`` is a single free local port shared by all
    workers.

    Raises:
        AssertionError: if any worker exits with a non-zero exit code. Workers
            still running at that point are terminated, so the caller is not
            left waiting on peers blocked on the failed rank.
    """
    from multiprocessing.connection import wait

    # Private "spawn" context: CUDA workers cannot use "fork". Using a context
    # avoids overriding the process-wide default start method.
    ctx = mp.get_context("spawn")

    distributed_init_port = get_open_port()
    procs = []
    try:
        for i in range(world_size):
            proc_args = (world_size, i, distributed_init_port) + target_args
            proc = ctx.Process(
                target=test_target, args=proc_args, name=f"Worker-{i}"
            )
            proc.start()
            procs.append(proc)

        # Reap workers in completion order so a failure is detected as soon as
        # that worker exits, even while lower ranks are still running/blocked.
        running = {proc.sentinel: (i, proc) for i, proc in enumerate(procs)}
        while running:
            for sentinel in wait(list(running)):
                i, proc = running.pop(sentinel)
                proc.join()
                if proc.exitcode != 0:
                    raise AssertionError(
                        f"Process {i} failed with exit code {proc.exitcode}"
                    )
    finally:
        # Never leave orphaned workers behind (e.g. peers of a crashed rank).
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
            proc.join()


class TestMSCCLAllReduce(unittest.TestCase):
    """Multi-GPU correctness test for the MSCCL++ allreduce kernel.

    For each entry in ``world_sizes``, spawns one worker per rank. Each worker
    checks ``mscclpp_allreduce`` against NCCL ``all_reduce`` for every element
    count in ``test_sizes``. A world size larger than the number of local GPUs
    is reported as a skipped subtest, because multi-node (ray) launching is not
    supported here.
    """

    # Element counts sent to each worker. The worker skips (size, dtype)
    # pairs that exceed its buffer limit.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    world_sizes = [8]

    def test_correctness(self):
        available_gpus = torch.cuda.device_count()
        for world_size in self.world_sizes:
            # subTest lets one world size be skipped without skipping the rest,
            # and makes the skip visible in the unittest report.
            with self.subTest(world_size=world_size):
                if world_size > available_gpus:
                    self.skipTest(
                        f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
                    )

                print(f"Running test for world_size={world_size}")
                multi_process_parallel(
                    world_size, _run_correctness_worker, target_args=(self.test_sizes,)
                )
                print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    # Allow running this file directly, e.g. `python test_mscclpp.py`.
    unittest.main(module="__main__")