# Scraped from a git web view of example_blocksparse_gemm.py (6.66 KB); the
# viewer reported: "...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git"
# did not exist on "48cc33e4ceba0371a5baa3ef98f6874d906e52e1".
1
2
import argparse
import itertools
3
4
import tilelang
import tilelang.language as T
5
from tilelang.engine.param import KernelParam
6
from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType
7
import torch
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import List

# Tile configuration used when the autotuner is disabled.
DEFAULT_BLOCK_M = 128
DEFAULT_BLOCK_N = 128
DEFAULT_BLOCK_K = 32
DEFAULT_NUM_STAGES = 2
DEFAULT_THREAD_NUM = 128
DEFAULT_ENABLE_RASTERIZATION = True

parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune")

# parse_known_args returns unrecognised flags instead of failing on them;
# presumably so the module can also be imported/run under another tool's argv.
args, _ = parser.parse_known_args()

# Reject inputs that can only be typos: non-positive sizes give an empty
# problem, and because the mask is `torch.rand(...) > sparsity`, any
# sparsity outside [0, 1] behaves exactly like the nearest endpoint
# (e.g. `--sparsity 50` silently yields an all-empty mask).
if min(args.m, args.n, args.k) <= 0:
    parser.error(f"--m/--n/--k must be positive, got M={args.m}, N={args.n}, K={args.k}")
if not 0.0 <= args.sparsity <= 1.0:
    parser.error(f"--sparsity must be within [0, 1], got {args.sparsity}")

M, N, K = args.m, args.n, args.k
sparsity = args.sparsity
use_autotune = args.use_autotune

# Fallback generator for every kernel input other than BlockMask
# (see supply_program).
default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto)

print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}")
print(f"Target Block Sparsity: {sparsity}")
print(f"Using Autotuner: {use_autotune}\n")
33
34


35
def get_configs():
    """Build the autotuning search space for ``blocksparse_matmul``.

    Returns:
        A list of keyword-argument dicts, one per point of the Cartesian
        product of the candidate values below (3*3*2*3*2*2 = 216 configs).
        Keys match the tunable parameter names of ``blocksparse_matmul``,
        including its ``enable_rasteration`` spelling, which must stay as-is.
    """
    search_space = {
        "block_M": [64, 128, 256],
        "block_N": [64, 128, 256],
        "block_K": [32, 64],
        "num_stages": [1, 2, 3],
        "thread_num": [128, 256],
        "enable_rasteration": [True, False],
    }
    # Dicts preserve insertion order, so the product enumerates configs in the
    # same order as nesting the lists above from first (outermost) to last.
    keys = list(search_space)
    return [dict(zip(keys, values)) for values in itertools.product(*search_space.values())]
56
57


58
59
60
61
def ref_program(A, B, BlockMask, block_M, block_N, block_K):
    """Naive PyTorch reference for the block-sparse matmul ``C = A @ B``.

    Output tile (i, j) accumulates ``A[i-tile, k-tile] @ B[k-tile, j-tile]``
    in float32 for every k with ``BlockMask[i, j, k]`` true, then is cast to
    float16.

    Args:
        A: Matrix of shape (M, K).
        B: Matrix of shape (K, N).
        BlockMask: Boolean mask indexed as [row tile, column tile, K tile].
        block_M, block_N, block_K: Tile sizes.

    Returns:
        A float16 tensor of shape (M, N) on ``A``'s device. Tiles are
        enumerated with floor division, so any trailing rows/columns that do
        not fill a whole tile stay zero.
    """
    # Take the problem size from the operands rather than the CLI globals,
    # so the reference is correct for any inputs it is given.
    M, K = A.shape
    N = B.shape[1]

    ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
    for i in range(M // block_M):
        rows = slice(i * block_M, (i + 1) * block_M)
        for j in range(N // block_N):
            cols = slice(j * block_N, (j + 1) * block_N)
            accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
            for k in range(K // block_K):
                if BlockMask[i, j, k]:
                    depth = slice(k * block_K, (k + 1) * block_K)
                    accu += A[rows, depth].to(torch.float32) @ B[depth, cols].to(torch.float32)
            ref_c[rows, cols] = accu.to(torch.float16)
    return ref_c


def supply_program(params: List[KernelParam]):
    """Generate kernel inputs, drawing BlockMask at the configured sparsity.

    Args:
        params: Kernel parameter descriptors, in kernel-argument order.

    Returns:
        One tensor per entry of ``params``. The 3-D parameter (BlockMask) is
        a boolean tensor on the current CUDA device; every other parameter
        comes from ``default_tensor_supply``.
    """
    input_tensors = []
    for p in params:
        # BlockMask is the kernel's only 3-D parameter (A, B and C are 2-D),
        # so rank alone identifies it.
        if len(p.shape) == 3:
            # Sample directly on the GPU instead of drawing on the host and
            # copying into a zero-filled CUDA tensor. torch.rand is uniform on
            # [0, 1), so each tile is kept with probability 1 - sparsity.
            device = torch.cuda.current_device()
            input_tensors.append(torch.rand(p.shape, device=device) > sparsity)
        else:
            input_tensors.append(default_tensor_supply(p))
    return input_tensors
89
90


91
92
93
@tilelang.autotune(
    configs=get_configs(),
    # supply_program is otherwise never referenced in this file, so without it
    # the autotuner would benchmark with its own default BlockMask rather than
    # one drawn at the requested --sparsity.
    # NOTE(review): confirm against tilelang's autotune `supply_prog` semantics.
    supply_prog=supply_program,
)
@tilelang.jit(out_idx=[-1])
def blocksparse_matmul(
    M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32
):
    """Build a TileLang kernel computing ``C = A @ B`` over unmasked tiles only.

    Args:
        M, N, K: Problem size; A is (M, K), B is (K, N), C is (M, N).
        block_M, block_N, block_K: Tile sizes.
        num_stages: Pipeline depth passed to ``T.Pipelined`` for the K loop.
        thread_num: Threads per block passed to ``T.Kernel``.
        enable_rasteration: Forwarded as ``enable`` to ``T.use_swizzle``.
        dtype: Element type of A, B and C.
        accum_dtype: Accumulator type for the per-block partial sums.

    Returns:
        The ``T.prim_func``. ``out_idx=[-1]`` marks C (the last argument) as
        the output, so callers invoke the kernel as ``kernel(A, B, BlockMask)``.
    """
    # One flag per (row tile, column tile, K tile).
    # NOTE(review): this uses floor division while the launch grid and K loop
    # use ceildiv, so M/N/K are presumably required to be multiples of the
    # tile sizes; otherwise the kernel would index past the mask.
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    @T.prim_func
    def block_sparse_matmul(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, "bool"),
        C: T.Tensor((M, N), dtype),
    ):
        # bx walks column tiles of C and by walks row tiles (see the final store).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Masked-out K tiles are skipped entirely: no loads, no gemm.
                if BlockMask[by, bx, k]:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)

            # Cast the accumulator to `dtype` via C_shared, then store the tile.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return block_sparse_matmul
126
127


128
def main():
    """Compile (or autotune) the block-sparse GEMM, run it once, and verify it.

    Uses the module-level problem size (M, N, K), ``sparsity`` and
    ``use_autotune`` parsed from the command line. The kernel output is
    compared against ``ref_program``; the outcome is printed, and a mismatch
    is reported but does not raise.
    """
    # Random half-precision inputs on the GPU.
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()

    if use_autotune:
        # Passing only the problem size leaves the tile parameters to the
        # autotuner; the returned kernel carries the winning config and its
        # measured latency.
        kernel = blocksparse_matmul(M, N, K)
        best_config = kernel.config
        best_latency = kernel.latency
        block_M = best_config["block_M"]
        block_N = best_config["block_N"]
        block_K = best_config["block_K"]

        print(f"Best Config: {best_config}")
        print(f"Sparsity Ratio: {sparsity}")
        print(f"Best Kernel Latency: {best_latency:.6f} ms")
    else:
        block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
        kernel = blocksparse_matmul(
            M,
            N,
            K,
            block_M=block_M,
            block_N=block_N,
            block_K=block_K,
            num_stages=DEFAULT_NUM_STAGES,
            thread_num=DEFAULT_THREAD_NUM,
            enable_rasteration=DEFAULT_ENABLE_RASTERIZATION,
        )
        print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")

    # One flag per (row tile, column tile, K tile); True means that tile
    # product is computed. Sampled directly on the inputs' device.
    mask_shape = (M // block_M, N // block_N, K // block_K)
    block_mask = torch.rand(mask_shape, device=a.device) > sparsity

    # C is the kernel's output argument, so only A, B and the mask are passed.
    c = kernel(a, b, block_mask)
    ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)

    try:
        torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    except AssertionError as e:
        print("❌ Verification FAILED: Results differ significantly.")
        print(e)
    else:
        print("✅ Results are close! Verification successful.")


if __name__ == "__main__":
    main()