Unverified Commit 29051439 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
...@@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type ...@@ -7,22 +7,13 @@ from tilelang.utils import map_torch_type
@tl.jit @tl.jit
def tensor_null_test(M, def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float",
with_bias=False):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), accum_dtype), C: T.Tensor((M, N), accum_dtype),
Bias: T.Tensor((N), accum_dtype), Bias: T.Tensor((N), accum_dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...@@ -48,12 +39,10 @@ def tensor_null_test(M, ...@@ -48,12 +39,10 @@ def tensor_null_test(M,
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
kernel = tensor_null_test( kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
kernel(a, b, c, None) kernel(a, b, c, None)
......
...@@ -28,9 +28,9 @@ def matmul( ...@@ -28,9 +28,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -136,9 +136,9 @@ def matmu_jit_kernel( ...@@ -136,9 +136,9 @@ def matmu_jit_kernel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -206,6 +206,7 @@ def run_gemm_jit_kernel( ...@@ -206,6 +206,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B): def ref_program(A, B):
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype) C = C.to(out_dtype)
return C return C
...@@ -233,19 +234,9 @@ def test_gemm_jit_kernel(): ...@@ -233,19 +234,9 @@ def test_gemm_jit_kernel():
) )
def run_nvrtc_kernel_do_bench(M, def run_nvrtc_kernel_do_bench(
N, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
K, ):
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M, ...@@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M,
def test_nvrtc_kernel_do_bench(): def test_nvrtc_kernel_do_bench():
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
256, 32, 2)
def run_nvrtc_kernel_multi_stream(
def run_nvrtc_kernel_multi_stream(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M, ...@@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M,
def test_nvrtc_kernel_multi_stream(): def test_nvrtc_kernel_multi_stream():
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
128, 256, 32, 2)
def run_nvrtc_dynamic_shape(
def run_nvrtc_dynamic_shape(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M, ...@@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c) matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_dynamic_shape(): def test_nvrtc_dynamic_shape():
run_nvrtc_dynamic_shape( run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape( run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_nvrtc_dynamic_shape( run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
def check_hopper(): def check_hopper():
...@@ -412,35 +375,18 @@ def check_hopper(): ...@@ -412,35 +375,18 @@ def check_hopper():
return compute_capability == (9, 0) return compute_capability == (9, 0)
def convolution_im2col(N, def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -449,11 +395,13 @@ def convolution_im2col(N, ...@@ -449,11 +395,13 @@ def convolution_im2col(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
data_shared: tilelang.layout.make_swizzled_layout(data_shared), out_shared: tilelang.layout.make_swizzled_layout(out_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), data_shared: tilelang.layout.make_swizzled_layout(data_shared),
}) kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -467,23 +415,9 @@ def convolution_im2col(N, ...@@ -467,23 +415,9 @@ def convolution_im2col(N,
return main return main
def run_nvrtc_im2col_tma_desc(N, def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
"""Test im2col TMA descriptor functionality in NVRTC backend.""" """Test im2col TMA descriptor functionality in NVRTC backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads)
num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc") conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
...@@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N, ...@@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N,
return C return C
ref_c = ref_program(a, b) ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_im2col_tma_desc(): def test_nvrtc_im2col_tma_desc():
"""Test im2col TMA descriptor with NVRTC backend.""" """Test im2col TMA descriptor with NVRTC backend."""
if not check_hopper(): if not check_hopper():
import pytest import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)") pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor # Small test case for im2col TMA descriptor
run_nvrtc_im2col_tma_desc( run_nvrtc_im2col_tma_desc(
N=4, N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256
C=64, )
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
def test_nvrtc_l2_persistent_map(): def test_nvrtc_l2_persistent_map():
...@@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map(): ...@@ -543,12 +465,11 @@ def test_nvrtc_l2_persistent_map():
block_size=256, block_size=256,
dtype="float32", dtype="float32",
): ):
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(M * N // block_size, threads=block_size) as bx: with T.Kernel(M * N // block_size, threads=block_size) as bx:
# Annotate L2 persistent cache for buffer B # Annotate L2 persistent cache for buffer B
......
...@@ -16,9 +16,9 @@ def matmul_kernel_jit( ...@@ -16,9 +16,9 @@ def matmul_kernel_jit(
block_K, block_K,
trans_A=False, trans_A=False,
trans_B=True, trans_B=True,
in_dtype='float16', in_dtype="float16",
out_dtype='float32', out_dtype="float32",
accum_dtype='float32', accum_dtype="float32",
num_stages=2, num_stages=2,
threads=128, threads=128,
): ):
...@@ -31,9 +31,9 @@ def matmul_kernel_jit( ...@@ -31,9 +31,9 @@ def matmul_kernel_jit(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......
...@@ -28,9 +28,9 @@ def matmul( ...@@ -28,9 +28,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -74,9 +74,9 @@ def matmu_jit_kernel( ...@@ -74,9 +74,9 @@ def matmu_jit_kernel(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -144,6 +144,7 @@ def run_gemm_jit_kernel( ...@@ -144,6 +144,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B): def ref_program(A, B):
import torch import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype) C = C.to(out_dtype)
return C return C
...@@ -171,19 +172,9 @@ def test_gemm_jit_kernel(): ...@@ -171,19 +172,9 @@ def test_gemm_jit_kernel():
) )
def run_tvm_ffi_kernel_do_bench(M, def run_tvm_ffi_kernel_do_bench(
N, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
K, ):
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M, ...@@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M,
def test_tvm_ffi_kernel_do_bench(): def test_tvm_ffi_kernel_do_bench():
run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
256, 32, 2)
def run_tvm_ffi_kernel_multi_stream(
def run_tvm_ffi_kernel_multi_stream(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M, ...@@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M,
def test_tvm_ffi_kernel_multi_stream(): def test_tvm_ffi_kernel_multi_stream():
run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
128, 256, 32, 2)
def run_tvm_ffi_dynamic_shape(
def run_tvm_ffi_dynamic_shape(M, M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
N, ):
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
program = matmul( program = matmul(
M, M,
N, N,
...@@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M, ...@@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c) matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype) tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_tvm_ffi_dynamic_shape(): def test_tvm_ffi_dynamic_shape():
run_tvm_ffi_dynamic_shape( run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape( run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_tvm_ffi_dynamic_shape( run_tvm_ffi_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
"float16", 128, 256, 32, 2) )
def check_hopper(): def check_hopper():
...@@ -350,35 +315,18 @@ def check_hopper(): ...@@ -350,35 +315,18 @@ def check_hopper():
return compute_capability == (9, 0) return compute_capability == (9, 0)
def convolution_im2col(N, def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
KH, KW = K, K KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@T.prim_func @T.prim_func
def main( def main(
data: T.Tensor((N, H, W, C), dtype), data: T.Tensor((N, H, W, C), dtype),
kernel: T.Tensor((KH, KW, C, F), dtype), kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype), out: T.Tensor((N, OH, OW, F), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype) data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype) kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype) out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -387,11 +335,13 @@ def convolution_im2col(N, ...@@ -387,11 +335,13 @@ def convolution_im2col(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({ T.annotate_layout(
out_shared: tilelang.layout.make_swizzled_layout(out_shared), {
data_shared: tilelang.layout.make_swizzled_layout(data_shared), out_shared: tilelang.layout.make_swizzled_layout(out_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), data_shared: tilelang.layout.make_swizzled_layout(data_shared),
}) kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
}
)
T.clear(out_local) T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
...@@ -405,23 +355,9 @@ def convolution_im2col(N, ...@@ -405,23 +355,9 @@ def convolution_im2col(N,
return main return main
def run_tvm_ffi_im2col_tma_desc(N, def run_tvm_ffi_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256):
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
"""Test im2col TMA descriptor functionality in tvm_ffi backend.""" """Test im2col TMA descriptor functionality in tvm_ffi backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads)
num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi") conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
...@@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N, ...@@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N,
return C return C
ref_c = ref_program(a, b) ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close( tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_tvm_ffi_im2col_tma_desc(): def test_tvm_ffi_im2col_tma_desc():
"""Test im2col TMA descriptor with tvm_ffi backend.""" """Test im2col TMA descriptor with tvm_ffi backend."""
if not check_hopper(): if not check_hopper():
import pytest import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)") pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor # Small test case for im2col TMA descriptor
run_tvm_ffi_im2col_tma_desc( run_tvm_ffi_im2col_tma_desc(
N=4, N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256
C=64, )
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
def test_tvm_ffi_l2_persistent_map(): def test_tvm_ffi_l2_persistent_map():
...@@ -481,12 +405,11 @@ def test_tvm_ffi_l2_persistent_map(): ...@@ -481,12 +405,11 @@ def test_tvm_ffi_l2_persistent_map():
block_size=256, block_size=256,
dtype="float32", dtype="float32",
): ):
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(M * N // block_size, threads=block_size) as bx: with T.Kernel(M * N // block_size, threads=block_size) as bx:
# Annotate L2 persistent cache for buffer B # Annotate L2 persistent cache for buffer B
...@@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map(): ...@@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map():
kernel = elementwise_add_with_l2_cache(M, N) kernel = elementwise_add_with_l2_cache(M, N)
source = kernel.get_host_source() source = kernel.get_host_source()
assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source" assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, (
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source" "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
)
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, (
"Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
)
# Create test tensors # Create test tensors
a = torch.randn(M, N, dtype=torch.float32).cuda() a = torch.randn(M, N, dtype=torch.float32).cuda()
......
...@@ -6,7 +6,8 @@ from tvm import DataType ...@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,) TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type from tilelang.utils.tensor import map_torch_type
...@@ -111,12 +112,11 @@ def tl_matmul( ...@@ -111,12 +112,11 @@ def tl_matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -124,10 +124,12 @@ def tl_matmul( ...@@ -124,10 +124,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -135,7 +137,6 @@ def tl_matmul( ...@@ -135,7 +137,6 @@ def tl_matmul(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -145,7 +146,6 @@ def tl_matmul( ...@@ -145,7 +146,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
......
...@@ -16,15 +16,15 @@ def elementwise_add( ...@@ -16,15 +16,15 @@ def elementwise_add(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), in_dtype), A: T.Tensor((M, N), in_dtype),
B: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N start_x = bx * block_N
start_y = by * block_M start_y = by * block_M
for (local_y, local_x) in T.Parallel(block_M, block_N): for local_y, local_x in T.Parallel(block_M, block_N):
y = start_y + local_y y = start_y + local_y
x = start_x + local_x x = start_x + local_x
......
...@@ -12,12 +12,11 @@ def calc_diff(x, y): ...@@ -12,12 +12,11 @@ def calc_diff(x, y):
def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype): def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), in_dtype), A: T.Tensor((M, K), in_dtype),
B: T.Tensor((N, K), in_dtype), B: T.Tensor((N, K), in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by):
A_shared = T.alloc_shared((bM, bK), in_dtype) A_shared = T.alloc_shared((bM, bK), in_dtype)
...@@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_ ...@@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
C = kernel(A, B) C = kernel(A, B)
ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype))
B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype))
print(C) print(C)
print(ref_c) print(ref_c)
diff = calc_diff(C, ref_c) diff = calc_diff(C, ref_c)
......
...@@ -6,7 +6,8 @@ from tvm import DataType ...@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,) TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type from tilelang.utils.tensor import map_torch_type
...@@ -110,12 +111,11 @@ def tl_matmul( ...@@ -110,12 +111,11 @@ def tl_matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -123,10 +123,12 @@ def tl_matmul( ...@@ -123,10 +123,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -134,7 +136,6 @@ def tl_matmul( ...@@ -134,7 +136,6 @@ def tl_matmul(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -144,7 +145,6 @@ def tl_matmul( ...@@ -144,7 +145,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
......
...@@ -27,8 +27,8 @@ def gemv_simt( ...@@ -27,8 +27,8 @@ def gemv_simt(
): ):
assert n_partition is not None, "n_partition must be provided" assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, ( assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
"sch_outer_reduction_with_config is not implemented") )
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
...@@ -50,16 +50,15 @@ def gemv_simt( ...@@ -50,16 +50,15 @@ def gemv_simt(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
Bias: T.Tensor(Bias_shape, out_dtype), Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor(C_shape, out_dtype), C: T.Tensor(C_shape, out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( bx,
bx, by,
by, ):
):
A_local = T.alloc_local((micro_size_k,), in_dtype) A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype) accum_res = T.alloc_local((1,), accum_dtype)
...@@ -88,13 +87,12 @@ def gemv_simt( ...@@ -88,13 +87,12 @@ def gemv_simt(
) )
else: else:
for ki in T.serial(micro_size_k): for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype)
accum_dtype)
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -104,11 +102,11 @@ def gemv_simt( ...@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res[0], reduced_accum_res[0],
kr, kr,
dtype="handle", dtype="handle",
)) )
)
if kr == 0: if kr == 0:
if with_bias: if with_bias:
C[by, C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else: else:
C[by, bx * n_partition + ni] = reduced_accum_res[0] C[by, bx * n_partition + ni] = reduced_accum_res[0]
......
...@@ -26,9 +26,9 @@ def matmul( ...@@ -26,9 +26,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -95,8 +95,8 @@ def run_gemm( ...@@ -95,8 +95,8 @@ def run_gemm(
if in_dtype == "float32": if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas # float32 automatically, -0x1000 meas
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float)) C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype)) C = C.to(torch.__getattribute__(out_dtype))
return C return C
...@@ -321,9 +321,9 @@ def matmul_sr( ...@@ -321,9 +321,9 @@ def matmul_sr(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -441,9 +441,9 @@ def matmul_rs( ...@@ -441,9 +441,9 @@ def matmul_rs(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared") A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
......
...@@ -6,7 +6,8 @@ from tvm import DataType ...@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,) TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type from tilelang.utils.tensor import map_torch_type
...@@ -111,12 +112,11 @@ def tl_matmul( ...@@ -111,12 +112,11 @@ def tl_matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -124,10 +124,12 @@ def tl_matmul( ...@@ -124,10 +124,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -135,7 +137,6 @@ def tl_matmul( ...@@ -135,7 +137,6 @@ def tl_matmul(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -145,7 +146,6 @@ def tl_matmul( ...@@ -145,7 +146,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
......
...@@ -76,12 +76,11 @@ def tl_matmul_simt( ...@@ -76,12 +76,11 @@ def tl_matmul_simt(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor(C_shape, out_dtype), C: T.Tensor(C_shape, out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
...@@ -97,7 +96,6 @@ def tl_matmul_simt( ...@@ -97,7 +96,6 @@ def tl_matmul_simt(
T.clear(C_local) T.clear(C_local)
for ko in T.serial(K // block_K): for ko in T.serial(K // block_K):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -109,29 +107,24 @@ def tl_matmul_simt( ...@@ -109,29 +107,24 @@ def tl_matmul_simt(
for ki in T.serial((block_K // micro_size_k)): for ki in T.serial((block_K // micro_size_k)):
for i in T.serial(local_size_a): for i in T.serial(local_size_a):
for mk in T.vectorized(micro_size_k): for mk in T.vectorized(micro_size_k):
A_local[i, mk] = A_shared[warp_m * local_size_a + i, A_local[i, mk] = A_shared[warp_m * local_size_a + i, ki * micro_size_k + mk]
ki * micro_size_k + mk]
for i in T.serial(local_size_b): for i in T.serial(local_size_b):
for mk in T.vectorized(micro_size_k): for mk in T.vectorized(micro_size_k):
B_local[i, mk] = B_shared[warp_n * local_size_b + i, B_local[i, mk] = B_shared[warp_n * local_size_b + i, ki * micro_size_k + mk]
ki * micro_size_k + mk]
for i, j in T.grid(local_size_a, local_size_b): for i, j in T.grid(local_size_a, local_size_b):
for mk in T.serial(micro_size_k // dp4a_size): for mk in T.serial(micro_size_k // dp4a_size):
if use_dp4a: if use_dp4a:
T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], C_local[i * local_size_b + j])
C_local[i * local_size_b + j])
else: else:
for dp4a_idx in T.serial(dp4a_size): for dp4a_idx in T.serial(dp4a_size):
C_local[i * local_size_b + C_local[i * local_size_b + j] += (
j] += A_local[i, mk * dp4a_size + A_local[i, mk * dp4a_size + dp4a_idx] * B_local[j, mk * dp4a_size + dp4a_idx]
dp4a_idx] * B_local[j, mk * dp4a_size + )
dp4a_idx]
for i, j in T.grid(local_size_a, local_size_b): for i, j in T.grid(local_size_a, local_size_b):
C[by * block_M + warp_m * local_size_a + i, C[by * block_M + warp_m * local_size_a + i, bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]
bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]
return main return main
......
...@@ -5,12 +5,11 @@ import torch ...@@ -5,12 +5,11 @@ import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...@@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int, ...@@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int,
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
# Create random input tensors on the GPU # Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16) a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16) b = torch.randn(K, N, device="cuda", dtype=torch.float16)
......
...@@ -27,8 +27,8 @@ def gemv_simt( ...@@ -27,8 +27,8 @@ def gemv_simt(
): ):
assert n_partition is not None, "n_partition must be provided" assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, ( assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
"sch_outer_reduction_with_config is not implemented") )
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
...@@ -50,16 +50,15 @@ def gemv_simt( ...@@ -50,16 +50,15 @@ def gemv_simt(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
Bias: T.Tensor(Bias_shape, out_dtype), Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor(C_shape, out_dtype), C: T.Tensor(C_shape, out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as ( bx,
bx, by,
by, ):
):
A_local = T.alloc_local((micro_size_k,), in_dtype) A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype) B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype) accum_res = T.alloc_local((1,), accum_dtype)
...@@ -88,13 +87,12 @@ def gemv_simt( ...@@ -88,13 +87,12 @@ def gemv_simt(
) )
else: else:
for ki in T.serial(micro_size_k): for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype)
accum_dtype)
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -104,11 +102,11 @@ def gemv_simt( ...@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res[0], reduced_accum_res[0],
kr, kr,
dtype="handle", dtype="handle",
)) )
)
if kr == 0: if kr == 0:
if with_bias: if with_bias:
C[by, C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else: else:
C[by, bx * n_partition + ni] = reduced_accum_res[0] C[by, bx * n_partition + ni] = reduced_accum_res[0]
......
...@@ -4,7 +4,8 @@ from tilelang import tvm as tvm ...@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import ( from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,) make_mma_swizzle_layout as make_swizzle_layout,
)
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter, INT4TensorCoreIntrinEmitter,
...@@ -91,12 +92,11 @@ def tl_matmul( ...@@ -91,12 +92,11 @@ def tl_matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -104,10 +104,12 @@ def tl_matmul( ...@@ -104,10 +104,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -115,7 +117,6 @@ def tl_matmul( ...@@ -115,7 +117,6 @@ def tl_matmul(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
...@@ -125,7 +126,6 @@ def tl_matmul( ...@@ -125,7 +126,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
...@@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): ...@@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
out_idx=[2], out_idx=[2],
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True, tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
}) },
)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
...@@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform( ...@@ -285,12 +286,11 @@ def tl_matmul_weight_only_transform(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype), B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform( ...@@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
B_shared: make_swizzle_layout(B_shared), A_shared: make_swizzle_layout(A_shared),
}) B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache # Improve L2 Cache
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
...@@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform( ...@@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform(
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage): for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, block_K): for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k] A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory # Load B into shared memory
for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k):
micro_size_y, micro_size_k): B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * (block_K // micro_size_k) + k, jj, kk]
B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j,
ko * (block_K // micro_size_k) + k, jj, kk]
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
...@@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform( ...@@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform(
def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
import bitblas import bitblas
matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
kernel = tilelang.compile(matmul, out_idx=[2]) kernel = tilelang.compile(matmul, out_idx=[2])
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
......
...@@ -6,16 +6,17 @@ import gc ...@@ -6,16 +6,17 @@ import gc
def test_tilelang_capture(): def test_tilelang_capture():
@tilelang.jit( @tilelang.jit(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},) },
)
def get_dummy_kernel(): def get_dummy_kernel():
@T.prim_func @T.prim_func
def dummy_kernel(a: T.Tensor[(1,), T.float32],): def dummy_kernel(
a: T.Tensor[(1,), T.float32],
):
with T.Kernel(1) as _: with T.Kernel(1) as _:
a[0] = 1 a[0] = 1
...@@ -36,5 +37,5 @@ def test_tilelang_capture(): ...@@ -36,5 +37,5 @@ def test_tilelang_capture():
# objgraph.show_backrefs([a_upgrade], max_depth=5) # objgraph.show_backrefs([a_upgrade], max_depth=5)
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -4,25 +4,25 @@ import tilelang.language as T ...@@ -4,25 +4,25 @@ import tilelang.language as T
def test_tilelang_intimm(): def test_tilelang_intimm():
T.int32(0x7fffffff) T.int32(0x7FFFFFFF)
T.int32(-0x7fffffff - 1) T.int32(-0x7FFFFFFF - 1)
T.uint32(0xffffffff) T.uint32(0xFFFFFFFF)
T.int64(0x7fffffffffffffff) T.int64(0x7FFFFFFFFFFFFFFF)
T.int64(-0x7fffffffffffffff - 1) T.int64(-0x7FFFFFFFFFFFFFFF - 1)
T.uint64(0xffffffffffffffff) T.uint64(0xFFFFFFFFFFFFFFFF)
a = T.int32() a = T.int32()
a & 0x7fffffff a & 0x7FFFFFFF
a = T.uint32() a = T.uint32()
a & 0xffffffff a & 0xFFFFFFFF
a = T.int64() a = T.int64()
a & 0x7fffffffffffffff a & 0x7FFFFFFFFFFFFFFF
a = T.uint64() a = T.uint64()
a & T.uint64(0xffffffffffffffff) a & T.uint64(0xFFFFFFFFFFFFFFFF)
if __name__ == '__main__': if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
...@@ -5,12 +5,11 @@ import tilelang.language as T ...@@ -5,12 +5,11 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
...@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): ...@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device) accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K): for k in range(K // block_K):
if torch.all(BlockMask[i, j, k]): if torch.all(BlockMask[i, j, k]):
accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
torch.float32) @ B[k * block_K:(k + 1) * block_K, k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
j * block_N:(j + 1) * block_N].to(torch.float32) ].to(torch.float32)
ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = ( ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
accu.to(torch.float16))
return ref_c return ref_c
...@@ -35,15 +34,14 @@ def blocksparse_matmul_global( ...@@ -35,15 +34,14 @@ def blocksparse_matmul_global(
dtype="float16", dtype="float16",
accum_dtype="float", accum_dtype="float",
): ):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"), BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -80,15 +78,14 @@ def blocksparse_matmul_shared( ...@@ -80,15 +78,14 @@ def blocksparse_matmul_shared(
dtype="float16", dtype="float16",
accum_dtype="float", accum_dtype="float",
): ):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"), BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -130,15 +127,14 @@ def blocksparse_matmul_local( ...@@ -130,15 +127,14 @@ def blocksparse_matmul_local(
dtype="float16", dtype="float16",
accum_dtype="float", accum_dtype="float",
): ):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim) block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
BlockMask: T.Tensor(block_mask_shape, "bool"), BlockMask: T.Tensor(block_mask_shape, "bool"),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi ...@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
# Create block mask with desired sparsity # Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K) mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity block_mask = torch.rand(mask_shape).cuda() > sparsity
...@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio ...@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
# Create block mask with desired sparsity # Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K) mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity block_mask = torch.rand(mask_shape).cuda() > sparsity
......
...@@ -10,8 +10,8 @@ def alloc_var( ...@@ -10,8 +10,8 @@ def alloc_var(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype) A_shared = T.alloc_shared([block_N], dtype)
...@@ -50,8 +50,8 @@ def alloc_var_add( ...@@ -50,8 +50,8 @@ def alloc_var_add(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype) A_shared = T.alloc_shared([block_N], dtype)
...@@ -91,8 +91,8 @@ def alloc_var_with_initializer( ...@@ -91,8 +91,8 @@ def alloc_var_with_initializer(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
tmp = T.alloc_var(dtype, init_value) tmp = T.alloc_var(dtype, init_value)
...@@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer( ...@@ -129,8 +129,8 @@ def alloc_multi_vars_with_initializer(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype), B: T.Tensor((N,), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx: with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
tmp0 = T.alloc_var(dtype, 1) tmp0 = T.alloc_var(dtype, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment