"...composable_kernel_onnxruntime.git" did not exist on "d2217f3040ef0219ec5f22fa5fd94ab66d9997ae"
Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
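The change is mechanical: yapf's hanging-indent layout is replaced by ruff format's Black-style layout (magic trailing commas, one closing bracket per line, double quotes). A minimal sketch of the recurring pattern in the hunks below is shown next; compile_program and the config key are illustrative placeholders, not code from this commit.

def compile_program(program, pass_configs):
    # Placeholder stub so the sketch runs on its own; not a real tilelang API.
    return program, pass_configs

program = "dummy kernel program"

# yapf (before): closing brackets hug the last argument, single quotes tolerated
kernel = compile_program(
    program,
    pass_configs={
        'TL_DISABLE_TMA_LOWER': True,
    })

# ruff format (after): trailing comma kept, each closing bracket on its own line, double quotes
kernel = compile_program(
    program,
    pass_configs={
        "TL_DISABLE_TMA_LOWER": True,
    },
)

Running ruff format once over the repository produces exactly these kinds of whitespace- and quote-only rewrites, which is what the per-file hunks below show.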
@@ -4,12 +4,11 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
...
@@ -27,9 +27,9 @@ def matmul(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -89,7 +89,8 @@ def run_gemm_ss(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
@@ -159,9 +160,9 @@ def matmul_rs(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -169,9 +170,11 @@ def matmul_rs(
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                }
+            )
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
@@ -225,7 +228,8 @@ def run_gemm_rs(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
@@ -294,9 +298,9 @@ def matmul_sr(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -304,9 +308,11 @@ def matmul_sr(
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
-            T.annotate_layout({
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                }
+            )
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
@@ -360,7 +366,8 @@ def run_gemm_sr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
@@ -430,9 +437,9 @@ def matmul_rr(
    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -441,10 +448,12 @@ def matmul_rr(
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                }
+            )
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
@@ -499,7 +508,8 @@ def run_gemm_rr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
...
@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
            low, high = (0, 4) if is_unsigned else (-2, 2)
        else:
            low, high = (0, 128) if is_unsigned else (-64, 64)
-        A = randint_semi_sparse(
-            M,
-            K,
-            low=low,
-            high=high,
-            dtype=map_torch_type(in_dtype),
-            device='cuda',
-            transposed=trans_A)
-        B = torch.randint(
-            size=(N, K) if trans_B else (K, N),
-            low=low,
-            high=high,
-            dtype=map_torch_type(in_dtype),
-            device='cuda')
+        A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
    else:
-        A = randn_semi_sparse(
-            M, K, dtype=torch.float32, device='cuda',
-            transposed=trans_A).to(map_torch_type(in_dtype))
-        B = torch.randn(
-            (N, K) if trans_B else (K, N), device='cuda',
-            dtype=torch.float32).to(map_torch_type(in_dtype))
+        A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype))
+        B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
    return A, B
@@ -69,24 +53,22 @@ def matmul_sp_sm90(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
-        E: T.Tensor((M, K // E_factor), 'uint8'),
+        E: T.Tensor((M, K // E_factor), "uint8"),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
-            E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
+            E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E:
-                    make_cutlass_metadata_layout(
-                        E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
-                E_shared:
-                    make_cutlass_metadata_layout(
-                        E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
+                }
+            )
            T.disable_warp_group_reg_alloc()
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
@@ -121,7 +103,7 @@ def matmul_sp_sm80(
    trans_B,
):
    is_8_bit = "8" in in_dtype
-    metadata_dtype = 'int32' if is_8_bit else 'int16'
+    metadata_dtype = "int32" if is_8_bit else "int16"
    E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
    A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
    B_shape = (K, N) if not trans_B else (N, K)
@@ -132,20 +114,22 @@ def matmul_sp_sm80(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
@@ -216,7 +200,7 @@ def run_gemm_sp(
    C = _matmul(A, B)
-    if 'float8' in in_dtype:
+    if "float8" in in_dtype:
        diff = calc_diff(C_sp, C)
        assert diff < 1e-3, f"{diff=}"
    else:
@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
-                     True)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
-                     False)
-    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
-                     True)
-    run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
-                     True)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False)
+    run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True)
+    run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True)
    run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False,
-                     True)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False,
-                     True)
-    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
-                     True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True)
+    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128)
    run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
...
@@ -34,20 +34,22 @@ def matmul(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
@@ -80,7 +82,7 @@ def run_gemm_ss(
    num_stages=3,
    num_threads=128,
):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
    program = matmul(
        M,
        N,
@@ -105,7 +107,8 @@ def run_gemm_ss(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
            low, high = (0, 4) if is_unsigned else (-2, 2)
        else:
            low, high = (0, 128) if is_unsigned else (-64, 64)
-        A = randint_semi_sparse(
-            M,
-            K,
-            low=low,
-            high=high,
-            dtype=map_torch_type(in_dtype),
-            device='cuda',
-            transposed=trans_A)
-        B = torch.randint(
-            size=(N, K) if trans_B else (K, N),
-            low=low,
-            high=high,
-            dtype=map_torch_type(in_dtype),
-            device='cuda')
+        A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
    else:
-        A = randn_semi_sparse(
-            M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A)
-        B = torch.randn(
-            (N, K) if trans_B else (K, N), device='cuda',
-            dtype=torch.float32).to(map_torch_type(in_dtype))
+        A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
+        B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
    return A, B
@@ -184,8 +172,7 @@ def test_gemm_ss():
    run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
    # float8 tests
-    run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64,
-                2)
+    run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
    run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
    # tfloat32 test
@@ -222,10 +209,10 @@ def matmul_rs(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -233,11 +220,13 @@ def matmul_rs(
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
@@ -271,7 +260,7 @@ def run_gemm_rs(
    num_stages=3,
    num_threads=128,
):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
    program = matmul_rs(
        M,
        N,
@@ -296,7 +285,8 @@ def run_gemm_rs(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
    C_sp = kernel(A_sparse, E, B)
@@ -376,10 +366,10 @@ def matmul_sr(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -387,11 +377,13 @@ def matmul_sr(
            E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
@@ -425,7 +417,7 @@ def run_gemm_sr(
    num_stages=3,
    num_threads=128,
):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
    program = matmul_sr(
        M,
        N,
@@ -450,7 +442,8 @@ def run_gemm_sr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
    C_sp = kernel(A_sparse, E, B)
@@ -531,10 +524,10 @@ def matmul_rr(
    @T.prim_func
    def main(
        A_sparse: T.Tensor(A_sparse_shape, in_dtype),
        E: T.Tensor((M, K // E_factor), metadata_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -543,12 +536,14 @@ def matmul_rr(
            A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
            B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
            C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.annotate_layout({
-                A_shared: tilelang.layout.make_swizzled_layout(A_shared),
-                B_shared: tilelang.layout.make_swizzled_layout(B_shared),
-                E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
-            })
+            T.annotate_layout(
+                {
+                    A_shared: tilelang.layout.make_swizzled_layout(A_shared),
+                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
+                    E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
+                    E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
+                }
+            )
            T.clear(C_frag)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
@@ -583,7 +578,7 @@ def run_gemm_rr(
    num_stages=3,
    num_threads=128,
):
-    metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
+    metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
    program = matmul_rr(
        M,
        N,
@@ -608,7 +603,8 @@ def run_gemm_rr(
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-        })
+        },
+    )
    A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
    A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
    C_sp = kernel(A_sparse, E, B)
...
@@ -11,22 +11,14 @@ def _check(original, transformed):
    mod = tl.transform.Simplify()(mod)
    mod = tl.transform.LowerOpaqueBlock()(mod)
    mod = tl.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
-                                   True)
+    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)


def test_trival_pipeline():
    @T.prim_func
    def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
        for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
-            for i in T.serial(
-                    0,
-                    1,
-                    annotations={
-                        "software_pipeline_stage": [0, 1],
-                        "software_pipeline_order": [0, 1]
-                    }):
+            for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}):
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(C[tx, i])
...
@@ -21,10 +21,8 @@ def _check(original, transformed):
def test_cluster_planning():
    @T.prim_func
-    def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
-            (1024, 1024), "float16")):
+    def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float16")
            B_shared = T.alloc_shared((32, 128), "float16")
@@ -41,8 +39,7 @@ def test_cluster_planning():
            T.copy(C_local, C[by * 128, bx * 128])

    @T.prim_func
-    def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
-            (1024, 1024), "float16")):
+    def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
        T.func_attr({"clusterIdx.y": T.int32(2)})
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float16")
...
@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    block_N = 64
    num_stages = 0
    threads = 128
-    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
+    scale = (1.0 / dim) ** 0.5 * 1.44269504  # log2(e)
    batch = T.int32(batch)
    heads = T.int32(heads)
@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
    block_mask_dtype = "bool"

    def kernel_func(block_M, block_N, num_stages, threads):
        @T.macro
        def MMA0(
            K: T.Tensor(shape, dtype),
@@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
            by: T.int32,
            bz: T.int32,
        ):
-            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
+            T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
-                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-                                                 -T.infinity(acc_s.dtype))
+                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
            V: T.Tensor(shape, dtype),
            V_shared: T.Tensor([block_M, dim], dtype),
            acc_s_cast: T.Tensor([block_M, block_N], dtype),
            acc_o: T.Tensor([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
-            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
+            T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.Tensor([block_M, block_N], accum_dtype),
            acc_s_cast: T.Tensor([block_M, block_N], dtype),
            scores_max: T.Tensor([block_M], accum_dtype),
            scores_max_prev: T.Tensor([block_M], accum_dtype),
            scores_scale: T.Tensor([block_M], accum_dtype),
            scores_sum: T.Tensor([block_M], accum_dtype),
            logsum: T.Tensor([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
@@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
        @T.macro
        def Rescale(
            acc_o: T.Tensor([block_M, dim], accum_dtype),
            scores_scale: T.Tensor([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
            Q: T.Tensor(shape, dtype),
            K: T.Tensor(shape, dtype),
            V: T.Tensor(shape, dtype),
            BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
            Output: T.Tensor(shape, dtype),
        ):
-            with T.Kernel(
-                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
+            with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)
-                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
+                T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
                    block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
                loop_range = (
-                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
-                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
+                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
+                )
                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
-                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
-                            scores_sum, logsum)
+                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
-                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
+                T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])

        return main
...
@@ -22,7 +22,6 @@ def _check(original, transformed):
def test_lower_fence_proxy():
    @T.prim_func
    def before():
        with T.Kernel(8):
@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
            B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
            C_local = T.decl_buffer((32,), scope="local")
            for i in T.unroll(16):
-                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
+                C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
-            T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
-                          "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
-                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+            T.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.tl_gemm"),
+                "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
+                T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+            )

    @T.prim_func
    def after():
@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
            B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
            C_local = T.decl_buffer((32,), scope="local")
            for i in T.unroll(16):
-                C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
+                C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
            T.fence_proxy_async()
-            T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
-                          "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
-                          T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
-                          T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
+            T.call_intrin(
+                "handle",
+                tir.op.Op.get("tl.tl_gemm"),
+                "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
+                T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
+                T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+            )

    _check(before, after)


def test_async_to_generic_no_double_fence():
    @T.prim_func
    def before():
        with T.Kernel(8):
@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
def test_proxy_hint_override():
    @T.prim_func
    def before():
        with T.Kernel(8):
@@ -123,7 +126,6 @@ def test_proxy_hint_override():
def test_tma_store_sync_injection():
    @T.prim_func
    def before():
        with T.Kernel(8):
@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
def test_wgmma_marked_async():
    @T.prim_func
    def before():
        with T.Kernel(1):
@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
            C_local = T.decl_buffer((32,), "float16", scope="local")
            A_shared[0] = T.float16(0)
            T.warpgroup_arrive()
-            T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
-                           "fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
-                           T.int32(0), T.bool(True), 1, 1)
+            T.ptx_wgmma_ss(
+                "float16",
+                "m64n64k16",
+                T.bool(True),
+                T.bool(True),
+                "fp16",
+                "fp16",
+                "fp16",
+                desc_a.data,
+                T.int32(0),
+                desc_b.data,
+                T.int32(0),
+                C_local.data,
+                T.int32(0),
+                T.bool(True),
+                1,
+                1,
+            )

    mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
...
@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
                    T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
                    if v - 128 == 0:
                        T.tma_load(
-                            T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
-                                                    0, 2, 2, 0), T.get_mbarrier(k % 3),
-                            T.tvm_access_ptr(
-                                T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
-                            k * 32, by * 64)
-                        T.evaluate(
-                            tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
+                            T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
+                            T.get_mbarrier(k % 3),
+                            T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
+                            k * 32,
+                            by * 64,
+                        )
+                        T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
            else:
                # Consumer branch - should have set_max_nreg(240, 1)
                for k in range(16):
                    T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
                    T.call_extern(
-                        "handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
-                        T.tvm_access_ptr(
-                            T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
-                        T.tvm_access_ptr(
-                            T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
-                        T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
-                    T.evaluate(
-                        tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
+                        "handle",
+                        "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
+                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
+                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
+                        T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+                    )
+                    T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))

    # Apply the InjectSetMaxNReg pass
    func = before
@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
    set_max_nreg_calls = []

    def collect_set_max_nreg(stmt):
-        if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
-                hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
+        if (
+            isinstance(stmt, tvm.tir.Evaluate)
+            and hasattr(stmt.value, "op")
+            and hasattr(stmt.value.op, "name")
+            and stmt.value.op.name == "tl.set_max_nreg"
+        ):
            set_max_nreg_calls.append(stmt.value)

    tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)

    # We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
-    assert len(set_max_nreg_calls
-               ) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
+    assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"

    print("InjectSetMaxNReg test passed!")
@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
    set_max_nreg_calls = []

    def collect_set_max_nreg(stmt):
-        if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
-                hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
+        if (
+            isinstance(stmt, tvm.tir.Evaluate)
+            and hasattr(stmt.value, "op")
+            and hasattr(stmt.value.op, "name")
+            and stmt.value.op.name == "tl.set_max_nreg"
+        ):
            set_max_nreg_calls.append(stmt.value)

    tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)

    # Should have no set_max_nreg calls when no_set_max_nreg is present
-    assert len(
-        set_max_nreg_calls
-    ) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
+    assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"

    print("InjectSetMaxNReg with no_set_max_nreg test passed!")
...
@@ -8,17 +8,21 @@ import pytest
auto_target = tvm.target.Target(determine_target("auto"))


-@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
-    (64, 64, 32, 128, 8, "float16"),
-])
+@pytest.mark.parametrize(
+    "block_M, block_N, block_K, threads, vec_load_b, dtype",
+    [
+        (64, 64, 32, 128, 8, "float16"),
+    ],
+)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
    N = tvm.te.var("n")
    K = tvm.te.var("k")

    def before():
        @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                        for vec in T.Parallel(vec_load_b):
-                            B_shared[i * (threads * vec_load_b // block_N) + t //
-                                     (block_N // vec_load_b), t % (block_N // vec_load_b) *
-                                     (block_N // vec_load_b) + vec] = T.if_then_else(
-                                         k * block_K + i * (threads * vec_load_b // block_N) + t //
-                                         (block_N // vec_load_b) < K and bx * block_N + t %
-                                         (block_N // vec_load_b) * (block_N // vec_load_b) < N,
-                                         B[k * block_K + i * (threads * vec_load_b // block_N) +
-                                           t // (block_N // vec_load_b), bx * block_N + t %
-                                           (block_N // vec_load_b) * (block_N // vec_load_b) + vec],
-                                         T.float16(0))
+                            B_shared[
+                                i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                            ] = T.if_then_else(
+                                k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                                and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                                B[
+                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                    bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                ],
+                                T.float16(0),
+                            )

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})
    def after():
        @T.prim_func
-        def main(B: T.Tensor((K, N), dtype),):
+        def main(
+            B: T.Tensor((K, N), dtype),
+        ):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
-                        if (k * block_K + i * (threads * vec_load_b // block_N) + t //
-                                (block_N // vec_load_b)) * N % vec_load_b == 0:
+                        if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0:
                            for vec in T.vectorized(vec_load_b):
-                                B_shared[i * (threads * vec_load_b // block_N) + t //
-                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
-                                         (block_N // vec_load_b) + vec] = T.if_then_else(
-                                             k * block_K + i *
-                                             (threads * vec_load_b // block_N) + t //
-                                             (block_N // vec_load_b) < K and bx * block_N + t %
-                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
-                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
-                                               t // (block_N // vec_load_b),
-                                               bx * block_N + t % (block_N // vec_load_b) *
-                                               (block_N // vec_load_b) + vec], T.float16(0))
+                                B_shared[
+                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                ] = T.if_then_else(
+                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                                    B[
+                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                    ],
+                                    T.float16(0),
+                                )
                        else:
                            for vec in T.serial(vec_load_b):
-                                B_shared[i * (threads * vec_load_b // block_N) + t //
-                                         (block_N // vec_load_b), t % (block_N // vec_load_b) *
-                                         (block_N // vec_load_b) + vec] = T.if_then_else(
-                                             k * block_K + i *
-                                             (threads * vec_load_b // block_N) + t //
-                                             (block_N // vec_load_b) < K and bx * block_N + t %
-                                             (block_N // vec_load_b) * (block_N // vec_load_b) < N,
-                                             B[k * block_K + i * (threads * vec_load_b // block_N) +
-                                               t // (block_N // vec_load_b),
-                                               bx * block_N + t % (block_N // vec_load_b) *
-                                               (block_N // vec_load_b) + vec], T.float16(0))
+                                B_shared[
+                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                ] = T.if_then_else(
+                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
+                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
+                                    B[
+                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
+                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
+                                    ],
+                                    T.float16(0),
+                                )

-        return tvm.IRModule({'main': main})
+        return tvm.IRModule({"main": main})

    with tvm.target.Target(auto_target):
        mod = tvm.tir.transform.BindTarget(auto_target)(before())
...
@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
    dtype = "float32"

    @T.prim_func
-    def main(A: T.Tensor((M, N), dtype=dtype),):
+    def main(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
        with T.Kernel(1, 1, threads=M) as (bx, by):
            A_shared = T.alloc_shared((M, N), dtype=dtype)
            tid = T.get_thread_binding()
@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
                A_shared[tid, j] = A[tid + M_offset, j + N_offset]

    @T.prim_func
-    def expected(A: T.Tensor((M, N), dtype=dtype),):
+    def expected(
+        A: T.Tensor((M, N), dtype=dtype),
+    ):
        with T.Kernel(1, 1, threads=M) as (bx, by):
            A_shared = T.alloc_shared((M, N), dtype=dtype)
            tid = T.get_thread_binding()
-            T.reads(A[tid + M_offset, N_offset:N + N_offset])
+            T.reads(A[tid + M_offset, N_offset : N + N_offset])
            for j in T.serial(N):
                A_shared[tid, j] = T.if_then_else(
-                    j + N_offset < N,
-                    T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
-                                   T.float32(0)), T.float32(0))
+                    j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0)
+                )

    return main, expected
@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
    # NOTE: This kernel is mainly to test some corner cases in boundary check
-    num_tokens = T.dynamic('num_tokens')
+    num_tokens = T.dynamic("num_tokens")
    num_threads = 128

    @T.prim_func
    def main(x: T.Tensor((num_tokens,), dtype="int64")):
        with T.Kernel(1, threads=num_threads) as _:
-            count = T.alloc_var('int')
+            count = T.alloc_var("int")
            thread_idx = T.get_thread_binding()
            for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
                idx = thread_idx + i * num_threads
...@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel(): ...@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
@T.prim_func @T.prim_func
def expected(x: T.Tensor((num_tokens,), dtype="int64")): def expected(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _: with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var('int') count = T.alloc_var("int")
thread_idx = T.get_thread_binding() thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)): for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads idx = thread_idx + i * num_threads
count += T.Cast("int32", count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
return main, expected return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64, def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
N: int = 64,
M_offset: int = 2,
N_offset: int = 2):
dtype = "float32" dtype = "float32"
@T.prim_func @T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),): def main(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype) A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding() tid = T.get_thread_binding()
...@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64, ...@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
T.atomic_add(A[tid + M_offset, j + N_offset], 1) T.atomic_add(A[tid + M_offset, j + N_offset], 1)
@T.prim_func @T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),): def expected(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype) A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding() tid = T.get_thread_binding()
T.reads(A[tid + M_offset, N_offset:N + N_offset]) T.reads(A[tid + M_offset, N_offset : N + N_offset])
for j in T.serial(N): for j in T.serial(N):
A_shared[tid, j] = T.if_then_else( A_shared[tid, j] = T.if_then_else(
j + N_offset < N, j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0)
T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], )
T.float32(0)), T.float32(0))
# Nest if-then-else is expected, do not flatten it to pass structural equal check # Nest if-then-else is expected, do not flatten it to pass structural equal check
if j + N_offset < N: # noqa: SIM102 if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M: if tid + M_offset < M:
...@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in ...@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
dtype = "float32" dtype = "float32"
@T.prim_func @T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),): def main(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding() tid = T.get_thread_binding()
for j in T.serial(N): for j in T.serial(N):
A[tid + M_offset, j + N_offset] = 1 A[tid + M_offset, j + N_offset] = 1
@T.prim_func @T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),): def expected(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding() tid = T.get_thread_binding()
T.writes(A[tid + M_offset, N_offset:N + N_offset]) T.writes(A[tid + M_offset, N_offset : N + N_offset])
for j in T.serial(N): for j in T.serial(N):
if j + N_offset < N: # noqa: SIM102 if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M: if tid + M_offset < M:
...
...@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
    vec_len = 8

    @T.prim_func
    def main(
        A: T.Tensor((M, N, vec_len), dtype="float32"),
    ):
        with T.Kernel(1, 1, threads=M) as (bx, by):
            A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
            tid = T.get_thread_binding()
...@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
                A_shared[tid, j, v] = A[tid, j, v]

    @T.prim_func
    def expected(
        A: T.Tensor((M, N, vec_len), dtype="float32"),
    ):
        with T.Kernel(1, 1, threads=M) as (bx, by):
            A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
            tid = T.get_thread_binding()
...
...@@ -8,12 +8,10 @@ def _check(original, transformed):
    func = original
    mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
    mod = tl.transform.LetInline()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)


def test_let_binding():
    @T.prim_func
    def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
        for i in range(128):
...@@ -34,7 +32,6 @@ def test_let_binding():
def test_parallel_scope():
    @T.prim_func
    def before(A: T.Tensor((128,), "float32")):
        for i in T.Parallel(128):
...
...@@ -24,7 +24,6 @@ def _check(original, transformed):
def test_lower_hopper_intrin_barrier():
    @T.prim_func
    def before():
        with T.Kernel(8):
...@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
        v_1 = T.launch_thread("threadIdx.x", 128)
        T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
        with T.If(v_1 == 0), T.Then():
            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))
            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128]))
            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128]))
            T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128]))
        T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))

    _check(before, after)
...
...@@ -8,63 +8,69 @@ import pytest
auto_target = tvm.target.Target(determine_target("auto"))


@pytest.mark.parametrize(
    "block_M, block_N, block_K, threads, vec_load_b, dtype",
    [
        (64, 64, 32, 128, 8, "float16"),
    ],
)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
    N = tvm.te.var("n")
    K = tvm.te.var("k")

    def before():
        @T.prim_func
        def main(
            B: T.Tensor((K, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    T.copy(B[k * block_K, bx * block_N], B_shared)

        return tvm.IRModule({"main": main})

    def after():
        @T.prim_func
        def main(
            B: T.Tensor((K, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
                B_shared = T.alloc_shared((block_K, block_N), dtype)
                thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                    t = thread_bindings
                    for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
                        if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0:
                            for vec in T.vectorized(vec_load_b):
                                B_shared[
                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                ] = T.if_then_else(
                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                    B[
                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                    ],
                                    T.float16(0),
                                )
                        else:
                            for vec in T.serial(vec_load_b):
                                B_shared[
                                    i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                    t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                ] = T.if_then_else(
                                    k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
                                    and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
                                    B[
                                        k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
                                        bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
                                    ],
                                    T.float16(0),
                                )

        return tvm.IRModule({"main": main})

    with tvm.transform.PassContext():
        mod = tvm.tir.transform.BindTarget(auto_target)(before())
...
...@@ -80,7 +80,6 @@ def test_target_host_removed():
    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer(1, "float32")):
            T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
...@@ -102,7 +101,6 @@ def test_internal_subroutine_call():
    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer(1, "float32")):
            T.func_attr({"target": T.target("llvm", host="llvm")})
...@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
    subroutine_call_op = compute_scope.body.value.op
    assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
        f"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
        f"but instead has an operation of type {subroutine_call_op}"
    )


def test_subroutine_call_to_externally_visible_subroutine():
...@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():
    @I.ir_module
    class before:
        @T.prim_func
        def main(A: T.Buffer(1, "float32")):
            T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
...@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
    assert subroutine_compute_scope is not None

    subroutine_call_op = main_compute_scope.body.value.op
    assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", (
        f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
        f"but instead has an operation of type {subroutine_call_op}"
    )


@tilelang.testing.requires_llvm
...@@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count():
    @T.prim_func
    def func(
        A: T.Buffer([16, 16], "int32"),
        B: T.Buffer([16, 16], "int32"),
        C: T.Buffer([16, 16], "int32"),
        D: T.Buffer([16, 16], "int32"),
    ):
        pass
...
...@@ -31,7 +31,6 @@ block_K = 32
def test_multi_version_buffer():
    @T.prim_func
    def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
        bx = T.launch_thread("blockIdx.x", 8)
...@@ -49,21 +48,27 @@ def test_multi_version_buffer():
            for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
                        k * 32,
                        by * 64,
                    )
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
                        bx * 64,
                        k * 32,
                    )
                T.call_extern(
                    "handle",
                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
                )

    @T.prim_func
    def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
...@@ -82,31 +87,32 @@ def test_multi_version_buffer():
            for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
                        k * 32,
                        by * 64,
                    )
                if v == 0:
                    T.tma_load(
                        T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
                        0,
                        T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
                        bx * 64,
                        k * 32,
                    )
                T.call_extern(
                    "handle",
                    "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
                    T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
                    T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
                )

    _check(before, after)


def test_multi_version_buffer_with_let():
    @T.prim_func
    def before(scales: T.Tensor((4,), "float32")):
        with T.block("root"):
...
...@@ -19,10 +19,8 @@ def _check(original, transformed):
def test_simple_pipeline():
    @T.prim_func
    def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float32")
            B_shared = T.alloc_shared((32, 128), "float32")
...@@ -39,8 +37,7 @@ def test_simple_pipeline():
            T.copy(C_local, C[by * 128, bx * 128])

    @T.prim_func
    def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            A_shared = T.alloc_shared((128, 32), "float32")
            B_shared = T.alloc_shared((32, 128), "float32")
...@@ -49,14 +46,13 @@ def test_simple_pipeline():
            T.clear(C_local)
            for ko in T.serial(
                32,
                annotations={
                    "software_pipeline_async_stages": [T.int32(0)],
                    "software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)],
                    "software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)],
                },
            ):
                T.copy(A[by * 128, ko * 32], A_shared)
                T.copy(B[ko * 32, bx * 128], B_shared)
                T.gemm(A_shared, B_shared, C_local)
...
...@@ -8,14 +8,13 @@ def modify(
    with_B: bool = False,
    with_bias: bool = False,
):
    @T.prim_func
    def main(
        A: T.Tensor((64, 64)),
        B: T.Tensor((64, 64)),
        C: T.Tensor((64, 64)),
        D: T.Tensor((64, 64)),
        bias: T.Tensor((64, 64)),
    ):
        if with_B:
            if with_bias:
...@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
        a: T.handle,
...@@ -76,6 +74,7 @@ def test_matmul():
    kernel = tl.compile(mod["main"], out_idx=[2])

    import torch

    a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
    b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
    c = kernel(a, b)
...
...@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
    cuda_target = tvm.target.Target("cuda", host="llvm")

    mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod)

    mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
    mod = tvm.tir.transform.SplitHostDevice()(mod)
...@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):
@tilelang.testing.requires_cuda
def test_sync_if_with_same_index():
    @T.prim_func(check_well_formed=False)
    def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
        threadIdx_x = T.env_thread("threadIdx.x")
...@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():
@tilelang.testing.requires_cuda
def test_sync_read_thread_id_independent_location():
    @T.prim_func
    def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
        threadIdx_x = T.env_thread("threadIdx.x")
...@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():
@tilelang.testing.requires_cuda
def test_sync_shared():
    @T.prim_func(private=True)
    def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
        blockIdx_x = T.launch_thread("blockIdx.x", 1)
...@@ -113,7 +106,6 @@ def test_sync_shared():
@tvm.testing.requires_cuda
def test_sync_let_stmt():
    @T.prim_func(private=True)
    def func(A: T.Buffer((16 * 512), "float32")):
        blockIdx_x = T.launch_thread("blockIdx.x", 16)
...@@ -136,9 +128,9 @@ def test_sync_let_stmt():
            in_thread_A_temp_1[0] = A_temp
        cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
        with T.attr(
            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
            "reduce_scope",
            T.reinterpret("handle", T.uint64(0)),
        ):
            T.tvm_thread_allreduce(
                T.uint32(1),
...@@ -190,16 +182,19 @@ def test_sync_let_stmt():
@tilelang.testing.requires_cuda
def test_sync_shared_dyn_stmatrix_loop_hoist():
    @T.prim_func
    def func():
        buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
        tx = T.launch_thread("threadIdx.x", 384)
        for i in T.unroll(8):
            off = (
                i // 4 * 8192
                + tx // 32 * 1024
                + tx % 16 * 64
                + (tx % 8 // 4 + i % 4 // 2) % 2 * 32
                + (tx % 4 // 2 + i % 2) % 2 * 16
                + (tx % 32 // 16 + tx % 2) % 2 * 8
            )
            T.evaluate(
                T.call_intrin(
                    "handle",
...@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
                        2,
                    ),
                    T.int32(2),
                )
            )

    mod = tvm.IRModule({"main": func})
    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
...