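# TileLang example: FP8 (e4m3) GEMM with fine-grained scaling and two-level
# accumulation ("2xAcc"), validated against a torch._scaled_mm reference.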
from typing import Tuple

import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(42)


@tilelang.jit
def tl_gemm(
    M,
    N,
    K,
    block_N,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float8_e4m3",
    ], "Currently only float8_e4m3 is supported"
    assert out_dtype in [
        "bfloat16",
        "float32",
    ], "Currently only float16 and float32 are supported"

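    # A is (M, K) and B is (N, K); the kernel computes C = A @ B^T with
    # fine-grained FP8 scaling: one float32 scale per 128-wide group along K
    # (per token for A, per 128x128 block for B). block_K equals group_size,
    # so each K iteration consumes exactly one scale.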
    group_size = 128
    block_M = 128
    block_K = 128

    A_shape = (M, K)
    Scales_A_shape = (M, T.ceildiv(K, group_size))
    B_shape = (N, K)
    Scales_B_shape = (T.ceildiv(N, group_size), T.ceildiv(K, group_size))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
        scales_a: T.Tensor(Scales_A_shape, "float32"),
        scales_b: T.Tensor(Scales_B_shape, "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            Scale_C_shared = T.alloc_shared((block_M,), "float32")
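            # Two-level accumulation ("2xAcc"): C_local holds a single K-block
            # partial sum, which is scaled and added into C_local_accum below.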
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                Scale_B = scales_b[bx * block_N // group_size, k]
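                # Combine the per-token A scale with the per-block B scale
                # for this K group.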
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def ceildiv(a, b):
    return (a + b - 1) // b


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
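    # Quantize each 128-element group of every row to FP8 (e4m3); 448 is the
    # largest finite value of float8_e4m3fn. Returns the FP8 tensor and the
    # per-group float32 scales.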
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
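    # Quantize x in 128x128 blocks (zero-padding up to multiples of 128);
    # returns the FP8 tensor cropped back to (m, n) and one float32 scale per block.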
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(ceildiv(m, 128) * 128, ceildiv(n, 128) * 128, dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))


def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
    # A_scale: (M, K//128)       ==>   (M//128, K//128, 128)
    # B_scale: (N//128, K//128)  ==>   (N//128, K//128, 128)
    # A_fp8: (M, K)
    # B_fp8: (N, K)
    # out_dtype: bfloat16 or float32
    # return C: (M, N)
    M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
    A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
    B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
    C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
    c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
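    # Accumulate each 128x128 output tile in float32 across K blocks,
    # mirroring the kernel's two-level accumulation.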
    for i in range(ceildiv(M, 128)):
        for j in range(ceildiv(N, 128)):
            c_acc.zero_()
            for k in range(ceildiv(K, 128)):
                c = torch._scaled_mm(
                    A_fp8[i * 128 : (i + 1) * 128, k * 128 : (k + 1) * 128],
                    B_fp8[j * 128 : (j + 1) * 128, k * 128 : (k + 1) * 128].T,
                    scale_a=A_scales[i, k].view(128, 1).contiguous(),
                    scale_b=B_scales[j, k].view(1, 128).contiguous(),
                    out_dtype=torch.bfloat16,
                )
                c_acc += c.to(torch.float32)
            C[i * 128 : (i + 1) * 128, j * 128 : (j + 1) * 128] = c_acc.to(out_dtype)
    return C


def calc_diff(x, y):
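    # Relative difference: 1 - 2*sum(x*y) / (sum(x*x) + sum(y*y)); 0 means identical.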
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtype):
    kernel = tl_gemm(M, N, K, block_N, in_dtype, out_dtype, accum_dtype)
    src_code = kernel.get_kernel_source()

    # src_code is the generated cuda source
    assert src_code is not None

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    A = torch.randn(M, K).to(torch.bfloat16).cuda()
    B = torch.randn(N, K).to(torch.bfloat16).cuda()
    A_fp8, A_scale = per_token_cast_to_fp8(A.clone())
    B_fp8, B_scale = per_block_cast_to_fp8(B.clone())

    C = torch.zeros(M, N, device="cuda", dtype=out_dtype)

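    # Run the compiled kernel; C is written in place.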
    kernel(A_fp8, B_fp8, C, A_scale, B_scale)
    # Get Reference Result
    ref_c = ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype)
    diff = calc_diff(C, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3

    profiler = kernel.get_profiler()
    latency = profiler.do_bench(warmup=25)
    # Ensure that the latency is not None
    assert latency is not None
    print(f"latency: {latency} ms")
    tflops = 2 * M * N * K / latency / 1e9
    print(f"tflops: {tflops}")


def main():
    assert_tl_gemm_correctness(1024, 1024, 8192, 128, "float8_e4m3", "bfloat16", "float32")


if __name__ == "__main__":
    for dtype in ["float8_e4m3"]:
        for out_dtype in ["bfloat16", "float32"]:
            for block_N in [16, 32, 64, 128]:
                assert_tl_gemm_correctness(1024, 1024, 8192, block_N, dtype, out_dtype, "float32")