# example_per_token_cast_to_fp8.py
import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close

# Switch off tilelang's cache before any kernel is built in this script
# (presumably so each run recompiles the kernel from scratch -- confirm).
tilelang.disable_cache()


def per_token_cast_to_fp8(M, N, blk_m):
    """Build a TileLang kernel that casts an (M, N) matrix to FP8 per row group.

    Every row is split into groups of ``group_size`` (128) consecutive
    columns. For each group the kernel takes the absolute maximum, raises it
    to at least 1e-4, derives the scale ``amax / 448`` and stores the values
    divided by that scale, clamped to [-448, 448], as ``e4m3_float8``.

    Args:
        M: Number of rows of the input matrix.
        N: Number of columns of the input matrix.
        blk_m: Number of rows processed by one kernel block.
            NOTE(review): the layout annotation divides by ``blk_m // 4``,
            which is 0 for ``blk_m < 4`` -- presumably ``blk_m`` is expected
            to be a multiple of 4; confirm.

    Returns:
        The ``T.prim_func`` ``main(X, X_fp8, X_amax)``. Despite its name,
        ``X_amax`` receives the per-group scale (``amax / 448``), not the raw
        absolute maximum.
    """
    # Element type of the input matrix, the scale output and the intermediates.
    dtype = "float"
    # Number of consecutive columns that share one scale.
    group_size = 128
    # Clamp bounds applied to the scaled values before the FP8 conversion.
    fp8_min = -448.0
    fp8_max = 448.0

    @T.prim_func
    def main(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"), X_amax: T.Tensor(
        (M, T.ceildiv(N, group_size)), dtype)):
        # Grid: one block per (blk_m rows, one column group), 128 threads per block.
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
            # bx selects the row block, by selects the column group.
            row = bx
            row_g_id = by
            # Per-block working buffers: the input tile, one amax and one scale
            # per row, the scaled tile in `dtype`, and the scaled tile in FP8.
            y_local = T.alloc_fragment((blk_m, group_size), dtype)
            y_amax_local = T.alloc_fragment((blk_m,), dtype)
            y_s_local = T.alloc_fragment((blk_m,), dtype)
            y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8")

            # Fix the thread assignment of y_local explicitly: element (i, j)
            # is mapped to thread (i // (blk_m // 4)) * 32 + j % 32.
            T.annotate_layout({
                y_local:
                    T.Fragment(
                        y_local.shape,
                        forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
            })

            # Load this block's (blk_m, group_size) tile of X.
            T.copy(
                X[row * blk_m:(row + 1) * blk_m, row_g_id * group_size:(row_g_id + 1) * group_size],
                y_local)
            # Absolute maximum of each tile row, reduced along the column axis.
            T.reduce_absmax(y_local, y_amax_local, dim=1)
            for i in T.Parallel(blk_m):
                # Lower-bound the amax so the scale below can never be zero.
                y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
                y_s_local[i] = y_amax_local[i] / fp8_max
            for i, j in T.Parallel(blk_m, group_size):
                # Scale each value by its row's scale and clamp to the FP8 bounds.
                y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
            # Copy into the e4m3 buffer (presumably this is where the dtype
            # conversion happens -- confirm against T.copy semantics).
            T.copy(y_q_local, y_q_local_fp8)
            for i in T.Parallel(blk_m):
                # Store the scale (not the amax) for this row and column group.
                # NOTE(review): there is no explicit bounds guard here for the
                # case where M is not a multiple of blk_m -- confirm that
                # TileLang guards this store.
                X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
            # Write the quantized tile back to the matching region of X_fp8.
            T.copy(
                y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
                                     row_g_id * group_size:(row_g_id + 1) * group_size])

    return main


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    The result is the smallest integer that is greater than or equal to
    ``x / y``. Negating a floor division rounds towards positive infinity for
    every sign combination, whereas ``(x + y - 1) // y`` is only a ceiling
    when ``y`` is positive.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.

    Raises:
        ZeroDivisionError: if ``y`` is zero.
    """
    return -(-x // y)


def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # this function don't support cpu tensor
    assert x.dim() == 2
    m, n = x.shape
    new_n = ceil_div(n, 128) * 128
    x_padded = torch.nn.functional.pad(x, (0, new_n - n))
    x_view = x_padded.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
    return x_fp8, (x_amax / 448.0).view(m, -1)


if __name__ == "__main__":
    # Problem size and rows-per-block used for both the check and the timings.
    M, N, blk_m = 8192, 8192, 8

    # Compile the TileLang kernel; arguments 1 and 2 (X_fp8, X_amax) are
    # returned by the compiled kernel instead of being passed in.
    kernel = tilelang.compile(
        per_token_cast_to_fp8(M, N, blk_m),
        out_idx=[1, 2],
        target="cuda",
        execution_backend="cython",
        pass_configs={"tl.disable_tma_lower": True},
    )
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

    x = torch.randn(M, N, device="cuda", dtype=torch.float32)

    # Run the kernel and the PyTorch reference on the same input.
    x_fp8, x_amax = kernel(x)
    x_fp8_ref, x_amax_ref = ref_program(x)

    labelled_results = (
        ("x_fp8:", x_fp8),
        ("x_amax:", x_amax),
        ("x_fp8_ref:", x_fp8_ref),
        ("x_amax_ref:", x_amax_ref),
    )
    for label, tensor in labelled_results:
        print(label, tensor, tensor.shape)

    # Compare in float32 with the same tolerances for values and scales.
    torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
    torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
    print("All checks pass.")

    # Timings: reference first, then the TileLang kernel.
    ref_ms = profiler.do_bench(ref_program, warmup=500)
    print("Ref: {:.2f} ms".format(ref_ms))
    tilelang_ms = profiler.do_bench()
    print("Tile-lang: {:.2f} ms".format(tilelang_ms))

    # Imported here, after the TileLang results are printed, so the Triton
    # example module is only needed for this last comparison.
    from tilelang.profiler import do_bench
    from example_triton_cast_to_fp8 import per_token_group_quant_fp8

    def run_triton():
        return per_token_group_quant_fp8(
            x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)

    # Call the Triton path once before it is timed.
    x_fp8_triton, x_amax_triton = run_triton()
    triton_ms = do_bench(run_triton)
    print("Triton: {:.2f} ms".format(triton_ms))