# test_fp8_blockwise_moe.py
import random

import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm


def cdiv(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up (ceiling division)."""
    quotient, remainder = divmod(a, b)
    # divmod floors the quotient, so any leftover remainder means the ceiling
    # is exactly one above it.
    return quotient + 1 if remainder else quotient


def scale_shape(shape, group_shape):
    """Return the number of groups needed along each dimension of ``shape``.

    One entry is produced per dimension listed in ``group_shape``; each entry
    is the ceiling of the matching ``shape`` extent over the group extent.
    """
    group_counts = []
    for axis, group_extent in enumerate(group_shape):
        group_counts.append(cdiv(shape[axis], group_extent))
    return tuple(group_counts)


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Compute a reference scaled matrix product in float32.

    Both operands are cast to float32, multiplied element-wise by their
    (group-broadcast) scales, multiplied together with ``torch.mm`` and the
    result is cast to ``out_dtype``.

    Args:
        a: Left operand; passed to ``torch.mm``, so it must be 2-D.
        b: Right operand; passed to ``torch.mm``, so it must be 2-D.
        scale_a: Scales for ``a``. Along each dimension its extent must equal
            the extent of ``a``, be 1, or evenly divide the extent of ``a``.
        scale_b: Scales for ``b``, with the same constraints relative to ``b``.
        out_dtype: dtype of the returned tensor (e.g. ``torch.half``).

    Returns:
        The scaled product ``(scale_a * a) @ (scale_b * b)`` as ``out_dtype``.

    Raises:
        AssertionError: If a scale extent neither matches, is 1, nor evenly
            divides the corresponding operand extent.
    """

    def group_broadcast(t, shape):
        # Stretch every grouped dimension of ``t`` up to the extent in
        # ``shape`` by repeating each entry for all elements of its group.
        for i, s in enumerate(shape):
            # Extents that already match, or are 1, are left to ordinary
            # broadcasting in the multiplication below.
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                # (.., g, ..) -> (.., g, 1, ..) -> (.., g, s // g, ..) -> (.., s, ..)
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    return torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)


# The kernel under test is only supported on sm100; the helper below gates the test on that.
def is_sm100_supported(device=None) -> bool:
    """Return True when ``device`` has compute capability 10.x and CUDA >= 12.8.

    Returns False instead of raising when CUDA is unavailable or when the
    build reports no parsable CUDA version, so the function is safe to call
    at import time (for example inside ``pytest.mark.skipif``).

    Args:
        device: Device forwarded to ``torch.cuda.get_device_capability``;
            ``None`` selects the current device.
    """
    # Querying the device capability raises on machines without usable CUDA.
    if not torch.cuda.is_available():
        return False

    cuda_version = torch.version.cuda
    if cuda_version is None:
        return False

    # Compare numerically: as strings, "12.10" sorts before "12.8" and "9.0"
    # sorts after it.
    try:
        major, minor = (int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False

    return (
        torch.cuda.get_device_capability(device)[0] == 10
        and (major, minor) >= (12, 8)
    )


@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
    """Check ``fp8_blockwise_scaled_grouped_mm`` against a per-expert reference.

    For every expert a random FP8 ``(m_g, k_g)`` activation and ``(k_g, n_g)``
    weight are generated together with random scales (one per 1x128 group of
    the activation, one per 128x128 block of the weight). The expected result
    is computed per expert with ``baseline_scaled_mm``; the inputs are then
    packed into stacked tensors, passed through the kernel in a single call,
    and each expert's slice of the output is compared with its reference.
    """
    device = "cuda"
    alignment = 16
    # N and K are drawn once and shared by all experts; only M varies.
    n_g = alignment * random.randint(1, 5) * 128
    k_g = alignment * random.randint(1, 5) * 128

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)

    # expert_offsets[g] is the first row of expert g in the stacked tensors.
    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
    # Zero-initialised buffers handed to the kernel.
    # NOTE(review): presumably filled in by the kernel itself - confirm.
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)

    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []

    for g in range(num_experts):
        m_g = alignment * random.randint(1, 64)
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)

        a_g = to_fp8(torch.randn((m_g, k_g), device=device))
        # b_g is a (k_g, n_g) transposed view of an (n_g, k_g) tensor.
        b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
        a_tensors.append(a_g)
        b_tensors.append(b_g)

        scale_a_shape = scale_shape(a_g.shape, scale_a_group_shape)
        scale_b_shape = scale_shape(b_g.shape, scale_b_group_shape)

        a_scales_tensors.append(torch.randn(scale_a_shape, device=device) * 0.001)
        b_scales_tensors.append(torch.randn(scale_b_shape, device=device) * 0.001)

        baseline = baseline_scaled_mm(
            a_g, b_g, a_scales_tensors[-1], b_scales_tensors[-1], out_dtype
        )
        baseline_tensors.append(baseline)

    a_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_stack = torch.empty(
        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )

    for g in range(num_experts):
        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
        b_stack[g] = b_tensors[g].t()
    # Stored as (num_experts, n_g, k_g); exposed as a (num_experts, k_g, n_g) view.
    b_stack = b_stack.transpose(1, 2)

    a_scale_stack = torch.empty(
        (expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
    )

    for g in range(num_experts):
        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
        b_scale_stack[g] = b_scales_tensors[g].t()
    # Same storage/view arrangement as b_stack above.
    b_scale_stack = b_scale_stack.transpose(1, 2)

    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
    )

    # Uninitialised scratch and pointer buffers handed to the kernel.
    # NOTE(review): presumably populated by the kernel before use - confirm.
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)

    fp8_blockwise_scaled_grouped_mm(
        c_out,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a_stack,
        b_stack,
        a_scale_stack,
        b_scale_stack,
        a_strides,
        # The A strides tensor is passed a second time here, as in the
        # original call (both hold k_g, the row stride of a_stack).
        a_strides,
        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets[:-1],
        workspace,
    )

    for g in range(num_experts):
        baseline = baseline_tensors[g]
        actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
        torch.testing.assert_close(actual, baseline, rtol=1e-2, atol=5e-4)
    # Report once, after every expert has been verified.
    print(f"num_experts={num_experts}, out_dtype={out_dtype}: OK")


if __name__ == "__main__":
    # pytest.main() returns the exit code rather than exiting; forward it so a
    # failing run gives a non-zero process status when run as a script.
    raise SystemExit(pytest.main([__file__]))