test_fp8_blockwise_gemm.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import unittest
from typing import Optional, Type

import torch
from sgl_kernel import fp8_blockwise_scaled_mm


def cdiv(a: int, b: int) -> int:
    """Return ``a / b`` rounded up to the next integer, using integer math only.

    Floor-divide first, then bump the quotient by one whenever the division
    left a nonzero remainder.
    """
    quotient, remainder = divmod(a, b)
    return quotient + (remainder != 0)


def scale_shape(shape, group_shape):
    """Return how many scale groups cover each dimension of ``shape``.

    Every extent is ceil-divided by the matching group extent, so a partial
    trailing group still gets its own entry. ``shape`` and ``group_shape``
    must have the same rank.
    """
    assert len(shape) == len(group_shape)
    counts = []
    for extent, group in zip(shape, group_shape):
        # Ceiling division: negate the floor of the division by -group.
        counts.append(-(extent // -group))
    return tuple(counts)


def baseline_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Unfused float32 reference for a block-scaled matrix multiply.

    Each operand is upcast to float32 and multiplied element-wise by the scale
    of the block the element belongs to. The two scaled operands are then
    multiplied with ``torch.mm`` and the product is cast to ``out_dtype``.

    Args:
        a: Left operand, 2-D ``(M, K)`` (``torch.mm`` requires 2-D inputs).
        b: Right operand, 2-D ``(K, N)``.
        scale_a: Block scales for ``a``. Each dimension must either equal the
            matching dimension of ``a``, be 1, or evenly divide it.
        scale_b: Block scales for ``b``, with the same rules relative to ``b``.
        out_dtype: dtype the float32 product is cast to, e.g.
            ``torch.bfloat16``.
        bias: Optional tensor added to the output after the cast to
            ``out_dtype``.

    Returns:
        ``(scale_a * a) @ (scale_b * b)`` cast to ``out_dtype``, plus ``bias``
        when one is given.
    """

    # Group broadcasting extends numpy-style broadcasting. Numpy only stretches
    # dimensions of extent 1 to match a target shape. Here, a source dimension
    # of extent n that differs from (and evenly divides) the target extent s
    # additionally has each element repeated s // n times along that
    # dimension. For example, with target_shape = (2, 4):
    #       [[1, 2],     expands to     [[1, 1, 2, 2],
    #        [3, 4]]                     [3, 3, 4, 4]]
    # Dimensions of extent 1 are left untouched; pytorch broadcasts them
    # implicitly in the element-wise multiply below.
    # NOTE(review): the group size is inferred as s // n. Scales sized with
    # ceil-division (e.g. via scale_shape) for a dimension that is not a
    # multiple of the block size can still pass the divisibility check but be
    # expanded with the wrong group size. Keep dims block-size aligned.
    def group_broadcast(t: torch.Tensor, shape) -> torch.Tensor:
        for i, s in enumerate(shape):
            n = t.shape[i]
            if n != s and n != 1:
                assert s % n == 0, (
                    f"scale extent {n} does not evenly divide target extent "
                    f"{s} in dim {i}"
                )
                # Insert a new axis after dim i, repeat each element s // n
                # times along it, then merge it back into dim i. The
                # `t.shape` reads below still see the pre-unsqueeze tensor.
                t = (
                    t.unsqueeze(i + 1)
                    .expand(*t.shape[: i + 1], s // n, *t.shape[i + 1 :])
                    .flatten(i, i + 1)
                )
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    # Dequantize both operands to float32 so the reference arithmetic does not
    # depend on the (possibly fp8) input dtype.
    output = torch.mm(
        scale_a * a.to(dtype=torch.float32),
        scale_b * b.to(dtype=torch.float32),
    ).to(out_dtype)

    if bias is not None:
        output = output + bias

    return output


class TestFp8Gemm(unittest.TestCase):
    """Accuracy tests for ``fp8_blockwise_scaled_mm`` against a float32 reference."""

    def _test_accuracy_once(self, M, N, K, out_dtype, device):
        """Compare the kernel with ``baseline_scaled_mm`` for one problem size.

        Builds random e4m3 operands ``a`` of shape ``(M, K)`` and ``b`` of
        shape ``(K, N)`` (the transpose of an ``(N, K)`` tensor). It uses
        ``(1, 128)`` block scales for ``a`` and ``(128, 128)`` block scales
        for ``b``, then checks the kernel output against the reference with
        ``rtol=0.02`` and ``atol=1``.
        """
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        # Uniform values spanning [-fp8_max, fp8_max), clamped to the e4m3
        # range before quantizing.
        a_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # b is generated as (N, K) and transposed, so b_fp8 is a (K, N) view
        # with column-major strides.
        b_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

        scale_a_group_shape = (1, 128)
        scale_b_group_shape = (128, 128)
        scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
        scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

        scale_a = (
            torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
        )
        scale_b = (
            torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
        )
        # Same shapes, re-laid out in column-major order.
        # NOTE(review): presumably the scale layout the kernel expects; confirm.
        scale_a = scale_a.t().contiguous().t()
        scale_b = scale_b.t().contiguous().t()

        expected = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
        actual = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)

        rtol = 0.02
        atol = 1
        # assert_close(actual, expected): rtol is applied relative to the
        # float32 reference, and failure messages label the tensors correctly.
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        """Sweep a grid of problem sizes and output dtypes on CUDA."""
        # Fixed seed so any tolerance failure is reproducible.
        torch.manual_seed(0)
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        out_dtypes = [torch.bfloat16, torch.float16]
        cases = [
            (M, N, K, out_dtype)
            for M in Ms
            for N in Ns
            for K in Ks
            for out_dtype in out_dtypes
        ]
        for M, N, K, out_dtype in cases:
            # subTest reports the failing configuration and keeps sweeping.
            with self.subTest(M=M, N=N, K=K, out_dtype=out_dtype):
                self._test_accuracy_once(M, N, K, out_dtype, "cuda")


if __name__ == "__main__":
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()