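"""Batched fast Walsh-Hadamard transform written in TileLang.

Each block processes one length-n row of the input and applies the
unnormalized n x n Hadamard matrix in three stages: butterflies inside a
thread's registers, butterflies across lanes via warp shuffle, and
cross-warp stages that go through a shared-memory transpose. The result
is checked against scipy.linalg.hadamard in ref_program.
"""
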
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout

import math
import argparse
import torch
from torch.nn import functional as F
import scipy


def is_pow_of_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


@tilelang.jit(out_idx=[1])
def hadamard(b, n, dtype):
    assert is_pow_of_2(n), "n must be a power of 2"
    assert 2 <= n <= 32768, "n must be in [2, 32768]"
    elem_size = {"float32": 4, "float16": 2, "bfloat16": 2}[dtype]

    logN = int(math.log2(n))
    threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
    thread_elem = n // threads  # Each thread is responsible for a chunk of elements
    thread_round = int(math.log2(thread_elem))

    warps = 1 if threads <= 32 else threads // 32
    warp_round = int(math.log2(threads / warps))
    warp_size = threads // warps

    block_round = int(math.log2(warps))

    exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1  # Assume at most 32 KB of shared memory is used
    thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem
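    # For illustration, with n = 32768 and dtype = "float32" the values above are:
    #   threads = 256, thread_elem = 128, thread_round = 7,
    #   warps = 8, warp_size = 32, warp_round = 5, block_round = 3,
    #   exchange_round = 4, thread_elem_in_smem = 32 (256 * 32 * 4 B = 32 KB of smem).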

    # debug log
    # print(f'{threads=}, {thread_round=}')
    # print(f'{warps=}, {warp_round=}, {warp_size=}')
    # print(f'{block_round=}')
    # print(f'{exchange_round=}')

    @T.macro
    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype), round: int):
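        # Performs `round` butterfly stages across lanes: in stage i each lane fetches
        # its partner's register chunk (the lane whose index differs in bit i) via
        # T.tvm_warp_shuffle and keeps either the sum or the difference of the two.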
        tx = T.get_thread_binding(0)
        for i in T.serial(round):
            tx_stride = 1 << i
            another_tx = tx ^ tx_stride
            sign = (tx >> i) & 1  # i-th lowest bit of tx, which decides whether this lane keeps the sum or the difference

            for j in T.Pipelined(thread_elem, num_stages=1):
                buf[j] = T.tvm_warp_shuffle(
                    0xFFFFFFFF,  # mask of all threads
                    local[j],
                    another_tx % warp_size,
                    warp_size,
                    warp_size,
                )
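                # The lower lane of each pair (bit i of tx == 0) keeps a + b; the upper
                # lane keeps a - b (partner value minus its own), i.e. one 2-point Hadamard.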
                local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])

    @T.prim_func
    def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)):
        with T.Kernel(b, threads=threads) as bx:
            local = T.alloc_local((thread_elem,), dtype)
            shared = T.alloc_shared((threads, thread_elem_in_smem), dtype)
            T.annotate_layout({shared: make_mma_swizzle_layout(shared)})
            tx = T.get_thread_binding(0)

            # 1. Load from HBM to register
            for i in T.vectorized(thread_elem):
                local[i] = A[bx, tx * thread_elem + i]

            # 2. Hadamard inside thread, n<=8
            for i in T.serial(thread_round):
                chunksize = 1 << (i + 1)
                chunknum = thread_elem // chunksize
                for j in T.serial(chunknum):
                    chunkbase = j * chunksize
                    for k in T.serial(chunksize // 2):
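                        # In-place 2-point butterfly: after the first line this slot holds
                        # a + b, so (a + b) - 2 * b recovers a - b without a temporary.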
                        local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
                        local[chunkbase + k + chunksize // 2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]

            # 3. Hadamard inside warp, n<=512
            # At the warp level, warp shuffles exchange data between lanes directly, without going through shared memory
            another_val = T.alloc_local((thread_elem,), dtype)

            warp_shfl(local, another_val, warp_round)

            # 4. Hadamard inside block, n<=32768
            # Only one exchange round is needed when n * elem_size <= 32 KB (n <= 8192 for float32), since shared memory then holds all elements
            if block_round > 0:
                warp_id = tx // warp_size
                lane_id = tx % warp_size
                src_tx = warp_id * warp_size + lane_id
                tgt_warp_id = tx % warps
                tgt_lane_id = tx // warps
                tgt_tx = tgt_warp_id * warp_size + tgt_lane_id
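                # src_tx is just tx written as (warp_id, lane_id); tgt_tx takes the warp
                # index from the low bits (tx % warps) and the lane index from the high
                # bits (tx // warps), i.e. a transposed thread-to-chunk mapping.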

                # 4.1 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[src_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[tgt_tx, j]

                # 4.2 Warp shuffle
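                # After the transpose, bit i of tx (i < block_round) tells which original
                # warp the chunk in registers came from, so block_round more shuffle
                # stages finish the cross-warp part of the transform.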
                warp_shfl(local, another_val, block_round)

                # 4.3 Write to smem, swap, read from smem
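                # Reverse exchange (write to tgt_tx, read from src_tx) restores the
                # original thread-to-chunk mapping before writing back to HBM.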
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[tgt_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[src_tx, j]

            # 5. Write back to HBM
            for i in T.vectorized(thread_elem):
                B[bx, tx * thread_elem + i] = local[i]

    return main


def ref_program(x: torch.Tensor):
    assert x.ndim == 2
    dim = x.shape[-1]
    assert is_pow_of_2(dim)
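    # scipy.linalg.hadamard returns the (symmetric) Sylvester Hadamard matrix,
    # so x @ H^T computed by F.linear equals x @ H.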
    return F.linear(x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=64, help="Batch size")
    parser.add_argument("--dim", type=int, default=32768, help="Dimension")
    args = parser.parse_args()

    B, D = args.batch, args.dim
    x = torch.randn((B, D), device="cuda")
    kernel = hadamard(B, D, "float32")
    y = kernel(x)
    y_ref = ref_program(x)
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
    print("All tests passed.")

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    latency = profiler.do_bench(warmup=100)
    print("Tile-lang: {:.2f} ms".format(latency))


if __name__ == "__main__":
    main()