# test_fp8_blockwise_moe.py
import random
from typing import Tuple

import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8_hopper_moe_mn_major,
)


def cdiv(a: int, b: int) -> int:
    # Ceiling division: -(a // -b) == ceil(a / b) for positive b.
    return -(a // -b)


def scale_shape(shape, group_shape):
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
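# For example, scale_shape((256, 512), (128, 128)) == (2, 4):
# one scale per 128x128 tile, cdiv along each dimension.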


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
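# calc_diff is 1 - 2<x, y> / (||x||^2 + ||y||^2): zero iff x == y, growing with
# both direction and magnitude error. For example, y = 1.01 * x gives
# diff = 1 - 2.02 / 2.0201 ~= 5e-5, well under the 1e-3 threshold used below.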


def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (128 - (n % 128)) % 128
    x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
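# Shape sketch (illustrative values): for x of shape (4, 250), n is padded to
# 256, so q, s = per_token_cast_to_fp8(x) gives q: (4, 250) float8_e4m3fn and
# s: (4, 2) float32 -- one scale per 128-wide group per row; 448.0 is the
# largest finite float8_e4m3fn value.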


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2)
    )
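# Shape sketch (illustrative values): for x of shape (200, 300), the tensor is
# zero-padded to (256, 384), so per_block_cast_to_fp8(x) returns q: (200, 300)
# float8_e4m3fn (sliced back out of the padding) and s: (2, 3) float32 -- one
# scale per 128x128 block.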


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:

    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    return torch.mm(
        (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
    ).to(out_dtype)
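# group_broadcast repeats each scale across its quantization group (e.g. a
# (2, 4) scale tensor expands to (256, 512) with each entry covering a 128x128
# block), so baseline_scaled_mm is plain dequantize-then-matmul in float32.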


def is_sm100_supported(device=None) -> bool:
    return (torch.cuda.get_device_capability(device)[0] == 10) and (
        torch.version.cuda >= "12.8"
    )


def is_sm90_supported(device=None) -> bool:
    return (torch.cuda.get_device_capability(device)[0] == 9) and (
        torch.version.cuda >= "12.3"
    )


@pytest.mark.skipif(
    not (is_sm100_supported() or is_sm90_supported()),
    reason="fp8_blockwise_scaled_grouped_mm in sgl-kernel is only supported on sm100 or sm90",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("use_custom_kernel", [True, False])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
    cc = torch.cuda.get_device_capability(None)[0]
    if cc == 10 and use_custom_kernel:
        # The custom quant kernel is Hopper (SM90)-only, so skip it on SM100.
        return
    device = "cuda"
    alignment = 16
    n_g = alignment * random.randint(1, 5) * 128
    k_g = alignment * random.randint(1, 5) * 128

    expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
    problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
    layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
    layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)

    a_original_tensors = []
    a_tensors = []
    b_tensors = []
    a_scales_tensors = []
    b_scales_tensors = []
    baseline_tensors = []

    for g in range(num_experts):
        m_g = alignment * random.randint(1, 64)
        expert_offsets[g + 1] = expert_offsets[g] + m_g
        problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)

        a = torch.randn((m_g, k_g), device=device, dtype=out_dtype)  # (M, K):(K, 1)
        b = torch.randn((n_g, k_g), device=device, dtype=out_dtype).t()  # (K, N):(1, K)

        a_g, a_scale = per_token_cast_to_fp8(
            a
        )  # a_g -- (M, K):(K, 1), a_scale -- (M, k):(k, 1)
        b_g, b_scale = per_block_cast_to_fp8(
            b
        )  # b_g -- (K, N):(N, 1), b_scale -- (k, n):(n, 1)
        a_original_tensors.append(a)
        a_tensors.append(a_g)
        b_tensors.append(b_g)
        a_scales_tensors.append(a_scale)
        b_scales_tensors.append(b_scale)

        baseline = torch.mm(a, b)
        baseline_tensors.append(baseline)
    a_original_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=out_dtype
    )
    a_stack = torch.empty(
        (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
    )
    b_stack = torch.empty(
        (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
    )
    a_scale_stack = torch.empty(
        (expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
    )
    b_scale_stack = torch.empty(
        (num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
    )

    for g in range(num_experts):
        # Matrix A is Row-Major.
        a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
            a_original_tensors[g]
        )
        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
            g
        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
        b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
        if cc == 9:
            # For SM90, we need MN-Major scale factor
            # a_scales_tensors[g] -- (M, k):(k, 1)
            # a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
            a_scale_stack[
                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
            ] = (a_scales_tensors[g].t().contiguous().view(-1))
            b_scale_stack[g] = b_scales_tensors[g]  # b_scale_stack[g] -- (k, n):(n, 1)
        elif cc == 10:
            # For SM100, we need K-Major scale factor
            # a_scales_tensors[g] -- (M, k):(k, 1)
            a_scale_stack[
                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
            ] = a_scales_tensors[g].view(-1)
            b_scale_stack[g] = b_scales_tensors[
                g
            ]  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
    a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
    b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
    if cc == 10:
        # SM100 needs K-Major scale factors for B: (k, n) -> (n, k), contiguous.
        b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()

    if use_custom_kernel:
        # Replace a_stack, a_scale_stack with custom kernel output
        a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
            a_original_stack,
            expert_offsets[:-1],
            problem_sizes,
            128,
            expert_tokens_alignment=alignment,
        )
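        # Per its name, the kernel emits MN-Major scale factors, i.e. the same
        # layout the SM90 branch above builds by hand with .t().contiguous();
        # it replaces the hand-built a_stack / a_scale_stack wholesale, and the
        # rest of the test runs unchanged.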

    c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
    a_strides = torch.full(
        (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
    )
    c_strides = torch.full(
        (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
    )
    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)

    fp8_blockwise_scaled_grouped_mm(
        c_out,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        a_stack,
        b_stack,
        a_scale_stack,
        b_scale_stack,
        a_strides,
        a_strides,  # also serves as b's strides: both leading strides are k_g here
        c_strides,
        layout_sfa,
        layout_sfb,
        problem_sizes,
        expert_offsets[:-1],
        workspace,
    )

    for g in range(num_experts):
        baseline = baseline_tensors[g]
        actual = c_out[expert_offsets[g] : expert_offsets[g + 1]]
        diff = calc_diff(actual, baseline)
        assert diff < 0.001
        print(
            f"cc={cc}0, num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
        )


if __name__ == "__main__":
    pytest.main([__file__])