test_fp8_gemm.py
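"""Accuracy tests for sgl_kernel.fp8_scaled_mm.

The kernel output is compared against a float32 PyTorch reference
(torch_scaled_mm) over a sweep of shapes, output dtypes, and bias options.
"""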
import unittest

import torch
from sgl_kernel import fp8_scaled_mm


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    # Reference path: accumulate in float32, apply the per-row and per-column
    # dequantization scales, cast to the output dtype, then add the bias.
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    o = o * scale_a.view(-1, 1) * scale_b.view(1, -1)
    o = o.to(out_dtype)
    if bias is not None:
        o = o + bias.view(1, -1)
    return o


class TestFp8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        # Random inputs spanning the fp8 e4m3 range, clamped and quantized to fp8.
        a_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Per-row scale for A and per-column scale for B (dequantization factors).
        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
        if with_bias:
            bias = torch.randn((N,), device=device, dtype=out_dtype)
        else:
            bias = None
        # B is created as (N, K) and transposed to a column-major (K, N) view for the GEMM.
        b_fp8 = b_fp8.t()
        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        # Loose tolerances: fp8 e4m3 carries only a few mantissa bits.
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(
                                M, N, K, with_bias, out_dtype, "cuda"
                            )


if __name__ == "__main__":
    unittest.main()