"...python/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "be44758c1e3b1dc1a7c9aadd69bc6a068d7f40ef"
Unverified commit 514bdeaa authored by Lei Wang, committed by GitHub

[Example] Add block level high performance gemv example (#1097)

* add alloc_reducer gemv example

* test
parent f003f371
@@ -216,75 +216,122 @@ def splitk_gemv_vectorized_tvm(
    return main


def get_block_template_configs():
    iter_params = dict(
        block_M=[2, 4, 8, 32, 64, 128],
        block_N=[2, 4, 8, 32, 64, 128],
        num_stages=[0, 1, 2, 3, 4],
        threads=[32, 64, 128, 256])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@tl.autotune(
    configs=get_block_template_configs(),
    warmup=3,
    rep=20,
)
@tl.jit(
    pass_configs={
        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    },
    out_idx=[2],
)
def gemv_alloc_reducer(M,
                       N,
                       block_M=128,
                       block_N=128,
                       num_stages=2,
                       threads=256,
                       dtype: str = "float16",
                       accum_dtype: str = "float"):

    @T.prim_func
    def main(a: T.Tensor((M, N), dtype), x: T.Tensor(N, dtype), o: T.Tensor(M, dtype)):  # type: ignore
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as i0_m:
            o_reducer = T.alloc_reducer(block_M, accum_dtype, replication="all")
            T.clear(o_reducer)
            for i0_n in T.Pipelined(T.ceildiv(N, block_N), num_stages=num_stages):
                a_smem = T.alloc_shared((block_M, block_N), dtype)
                T.copy(a[i0_m * block_M, i0_n * block_N], a_smem)
                a_frag = T.alloc_fragment((block_M, block_N), dtype)
                T.copy(a_smem, a_frag)
                x_frag = T.alloc_fragment(block_N, dtype)
                T.copy(x[i0_n * block_N], x_frag)
                for i1_m, i1_n in T.Parallel(block_M, block_N):
                    o_reducer[i1_m] += a_frag[i1_m, i1_n] * x_frag[i1_n]
            T.finalize_reducer(o_reducer)
            T.copy(o_reducer, o[i0_m * block_M])

    return main


def get_thread_template_configs():
    iter_params = dict(BLOCK_N=[2, 4, 8, 32, 64, 128], reduce_threads=[4, 8, 32])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(
    configs=get_thread_template_configs(),
    warmup=3,
    rep=20,
)
@jit(
    out_idx=[-1],
    target="auto",
)
def get_autotuned_kernel(
    N,
    K,
    BLOCK_N=None,
    reduce_threads=None,
):
    dtype = "float16"
    accum_dtype = "float"
    MAX_TRANSACTION_SIZE_IN_BITS = 128
    TILE_K = MAX_TRANSACTION_SIZE_IN_BITS // DataType(dtype).bits
    BLOCK_K = reduce_threads * TILE_K

    @T.prim_func
    def main(
            A: T.Tensor((K,), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=(BLOCK_N, reduce_threads)) as bn:
            tn = T.get_thread_binding(0)
            tk = T.get_thread_binding(1)
            A_local = T.alloc_local((TILE_K,), dtype)
            B_local = T.alloc_local((TILE_K,), dtype)
            C_accum = T.alloc_local((1,), accum_dtype)

            T.clear(C_accum)
            for bk in T.serial(T.ceildiv(K, BLOCK_K)):
                for k in T.vectorized(TILE_K):
                    A_local[k] = A[bk * BLOCK_K + tk * TILE_K + k]
                    B_local[k] = B[bn * BLOCK_N + tn, bk * BLOCK_K + tk * TILE_K + k]
                for k in T.serial(TILE_K):
                    C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
            C_reduced = T.alloc_local((1,), accum_dtype)
            with T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
            ):
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        C_accum[0],
                        True,
                        C_reduced[0],
                        tk,
                        dtype="handle",
                    ))

            C[bn * BLOCK_N + tn] = C_reduced[0]

    return main


def check_correctness_and_bench(kernel, N, K, bench_ref=True):
@@ -297,7 +344,7 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
    print(f"TileLang Latency: {latency} ms\n")


def main(do_bench: bool = True):
    parser = argparse.ArgumentParser(description="GEMV Example")
    parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
@@ -308,16 +355,23 @@ def main():
    check_correctness_and_bench(splitk_gemv(N, K, 32, 32, 32), N, K)
    check_correctness_and_bench(splitk_gemv_vectorized(N, K, 2, 32), N, K)
    check_correctness_and_bench(splitk_gemv_vectorized_tvm(N, K, 2, 32), N, K)
    check_correctness_and_bench(gemv_alloc_reducer(N, K, block_M=128, block_N=128), N, K)
    print("Test passed!")

    if not do_bench:
        best_result = get_autotuned_kernel(N, K)
        best_config = best_result.config
        kernel = splitk_gemv_vectorized_tvm(N, K, **best_config)
        profiler = kernel.get_profiler()
        latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
        print(f"Torch Latency: {latency} ms")
        tilelang_thread_latency = profiler.do_bench(kernel, warmup=500)
        print(f"TileLang SIMT Latency: {tilelang_thread_latency} ms\n")

        kernel = gemv_alloc_reducer(N, K)
        profiler = kernel.get_profiler()
        tilelang_tile_latency = profiler.do_bench(kernel, warmup=500)
        print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n")


if __name__ == "__main__":
...
@@ -4,7 +4,7 @@ import example_gemv


def test_example_gemv():
    example_gemv.main(do_bench=False)


if __name__ == "__main__":
...
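As a rough illustration of how the new block-level kernel added in this commit might be called outside of check_correctness_and_bench, here is a minimal sketch that invokes gemv_alloc_reducer directly and compares it against a PyTorch reference. It is not part of the commit: it assumes the @tl.autotune/@tl.jit wrappers return a callable kernel and that out_idx=[2] makes the output buffer o allocated and returned by the call (mirroring the kernel = gemv_alloc_reducer(N, K) usage in main() above); the sizes, tensor names, and tolerances are hypothetical.

# Illustrative usage sketch (assumptions noted above, not from the commit).
import torch
import example_gemv

M, N = 1024, 1024
# Same call shape as used by check_correctness_and_bench in the diff above.
kernel = example_gemv.gemv_alloc_reducer(M, N, block_M=128, block_N=128)

a = torch.randn(M, N, dtype=torch.float16, device="cuda")  # (M, N) matrix
x = torch.randn(N, dtype=torch.float16, device="cuda")     # (N,) vector

o = kernel(a, x)                                  # assumed: output o returned via out_idx=[2]
ref = (a.float() @ x.float()).to(torch.float16)   # fp32 reference GEMV, cast back to fp16
torch.testing.assert_close(o, ref, rtol=1e-2, atol=1e-2)
print("block-level GEMV matches torch reference")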