Commit 47ecc791 authored by Lei Wang, committed by GitHub

[Doc] Remove unnecessary layout annotation (#49)

* [Doc] Update documentation structure and content: add overview section, revise project name, and change theme to Furo

* [Feature] Add device-side debug printing functions and integrate into kernel interface (usage sketch below)

* lint fix

* remove debug print

* implement test for debug

* lint fix

* add some comments

* Enhance fragment design and assert fragment print

* enhance debug print

* add test for msg

* lint fix

* format

* add flash decoding examples

* remove comment

* test simplified
parent d86db0f9
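
The debug-printing feature mentioned in the notes above is easiest to picture with a small example. The following is a minimal sketch, not code from this commit: it assumes the helper is exposed as T.print and accepts a buffer/fragment or a scalar plus an optional msg= label (the "add test for msg" note suggests such an argument exists).

# Minimal sketch of device-side debug printing in a TileLang kernel.
# Assumptions: T.print is the exposed helper and takes an optional msg= label.
import tilelang.language as T


@T.prim_func
def debug_kernel(A: T.Buffer((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        A_shared = T.alloc_shared((128, 128), "float16")
        T.copy(A, A_shared)
        T.print(A_shared, msg="A_shared after copy")  # print a shared buffer (assumed signature)
        T.print(bx, msg="block index")                # print a scalar (assumed signature)

# Lowering/running would follow the same pattern as elsewhere in this commit:
# mod, params = tilelang.lower(debug_kernel)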
@@ -129,7 +129,6 @@ def flashattn(batch, heads, seq_len, dim, is_casual, tune=False):
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
-            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
             T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
@@ -208,7 +207,7 @@ if __name__ == "__main__":
     if (not args.tune):
         program = flashattn(
             batch, heads, seq_len, dim, is_casual, tune=args.tune)(
-                block_M=128, block_N=128, num_stages=2, threads=256)
+                block_M=128, block_N=128, num_stages=1, threads=128)
         ref_program = partial(ref_program, is_casual=is_casual)
         mod, params = tilelang.lower(program)
         mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
...
@@ -123,7 +123,6 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_
             bid = by // heads
             sid = bz
-            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
             T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
...
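
Both hunks drop the same explicit layout annotation, which the commit title describes as unnecessary here. For reference, a minimal sketch (illustrative shapes and dtype, not code from this commit) of where such an annotation would sit if one still wanted to force a swizzled shared-memory layout:

# Sketch only: shows the API shape of the annotation removed above.
# The buffer shape and dtype are illustrative assumptions.
import tilelang
import tilelang.language as T


@T.prim_func
def load_q(Q: T.Buffer((128, 64), "float16")):
    with T.Kernel(1, threads=128) as bx:
        Q_shared = T.alloc_shared((128, 64), "float16")
        # Optional, per this commit; kept here only to illustrate the call.
        T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
        T.copy(Q, Q_shared)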
@@ -300,18 +300,4 @@ def test_pad_f16f16f32_nn():
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    run_gemm(
-        512,
-        1024,
-        768,
-        False,
-        True,
-        "float16",
-        "float16",
-        "float16",
-        128,
-        256,
-        32,
-        2,
-    )
+    tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    mod, params = tl.lower(program)
    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_gemm():
    # GEMM tests for float16
    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
             2)  # f16f16f16_nn
    run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32,
             2)  # f16f16f16_tn
    run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32,
             2)  # f16f16f16_nt
    run_gemm(512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256,
             32, 2)  # pad_aligned_f16f16f16_nn
    run_gemm(512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256,
             32, 2)  # pad_f16f16f16_nn
    # GEMM tests for mixed precision (float16 + float32)
    run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128,
             32)  # f16f16f32_nn
    run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64,
             32)  # pad_f16f16f32_nn
    # GEMM tests for bfloat16
    run_gemm(512, 1024, 768, False, False, "bfloat16", "bfloat16", "float32", 128, 128,
             32)  # bf16bf16f32_nn
    # GEMM tests for float32
    run_gemm(512, 1024, 768, False, False, "float32", "float32", "float32", 64, 128,
             32)  # f32f32f32_nn
    run_gemm(512, 1024, 768, False, True, "float32", "float32", "float32", 64, 128,
             32)  # f32f32f32_nt
    run_gemm(512, 1024, 768, True, False, "float32", "float32", "float32", 64, 128,
             32)  # f32f32f32_tn
    # GEMM tests for float64
    run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32,
             16)  # f64f64f64_nt
    # GEMM tests for int8
    run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64)  # i8i8i32_nn
    run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64)  # i8i8i32_nt
    run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64)  # i8i8i32_tn
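

# matmul_rs mirrors matmul above, but the A tile is additionally copied from
# shared memory into a register fragment (A_frag) and T.gemm consumes that
# fragment directly ("rs" presumably: A operand in registers, B in shared memory).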
def matmul_rs(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    A_frag_shape = A_shared_shape

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                    T.copy(A_shared, A_frag)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(A_shared, A_frag)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_frag, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_rs(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    program = matmul_rs(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    mod, params = tl.lower(program)
    profiler = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer)

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_gemm_rs():
    # GEMM tests for float16
    run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
    run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -45,7 +45,7 @@ class Fragment(Layout):
             thread_replicate = None
         forward_thread = forward_thread_fn(*vars)
-        if not isinstance(forward_index, tvm.ir.container.Array) and forward_index is not None:
+        if forward_index is not None and not isinstance(forward_index, tvm.ir.container.Array):
             forward_index = [forward_index]
         self.__init_handle_by_constructor__(
...