Unverified commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)
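This commit only swaps the formatter; the hunks below are mechanical reformatting with no behavior change. As context, here is a minimal sketch of how one might verify locally that a checkout already matches the Ruff formatter's output. The script name (`check_format.py`) and the idea of wrapping the CLI in Python are illustrative assumptions, not part of this commit; only the `ruff format --check` invocation itself is a real Ruff CLI usage.

```python
# check_format.py -- hypothetical helper, not added by this commit.
# Runs `ruff format --check` so a reviewer (or CI) can confirm that no file
# still carries the old Yapf layout. `--check` reports files that would be
# reformatted without modifying them and exits non-zero if any are found.
import subprocess
import sys


def main() -> int:
    result = subprocess.run(["ruff", "format", "--check", "."])
    return result.returncode


if __name__ == "__main__":
    sys.exit(main())
```

Running `python check_format.py` (or `ruff format --check .` directly) exits with a non-zero status when any file would be rewritten by the formatter.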

parent e84b24bc
......@@ -4,12 +4,11 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
......
......@@ -27,9 +27,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -89,7 +89,8 @@ def run_gemm_ss(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
......@@ -159,9 +160,9 @@ def matmul_rs(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -169,9 +170,11 @@ def matmul_rs(
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
......@@ -225,7 +228,8 @@ def run_gemm_rs(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
......@@ -294,9 +298,9 @@ def matmul_sr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -304,9 +308,11 @@ def matmul_sr(
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
......@@ -360,7 +366,8 @@ def run_gemm_sr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
......@@ -430,9 +437,9 @@ def matmul_rr(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -441,10 +448,12 @@ def matmul_rr(
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}
)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
......@@ -499,7 +508,8 @@ def run_gemm_rr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
......
......@@ -20,27 +20,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low, high = (0, 4) if is_unsigned else (-2, 2)
else:
low, high = (0, 128) if is_unsigned else (-64, 64)
A = randint_semi_sparse(
M,
K,
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda',
transposed=trans_A)
B = torch.randint(
size=(N, K) if trans_B else (K, N),
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda')
A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
else:
A = randn_semi_sparse(
M, K, dtype=torch.float32, device='cuda',
transposed=trans_A).to(map_torch_type(in_dtype))
B = torch.randn(
(N, K) if trans_B else (K, N), device='cuda',
dtype=torch.float32).to(map_torch_type(in_dtype))
A = randn_semi_sparse(M, K, dtype=torch.float32, device="cuda", transposed=trans_A).to(map_torch_type(in_dtype))
B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
return A, B
......@@ -69,24 +53,22 @@ def matmul_sp_sm90(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'uint8'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), "uint8"),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
E_shared = T.alloc_shared((block_M, block_K // E_factor), "uint8")
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_cutlass_metadata_layout(
E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="9.0", block_k=block_K),
}
)
T.disable_warp_group_reg_alloc()
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
......@@ -121,7 +103,7 @@ def matmul_sp_sm80(
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
metadata_dtype = "int32" if is_8_bit else "int16"
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
......@@ -132,20 +114,22 @@ def matmul_sp_sm80(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
}
)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......@@ -216,7 +200,7 @@ def run_gemm_sp(
C = _matmul(A, B)
if 'float8' in in_dtype:
if "float8" in in_dtype:
diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}"
else:
......@@ -332,15 +316,11 @@ def test_gemm_sp_sm90():
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
True)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
False)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
True)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True)
run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
True)
run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True)
run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
......@@ -352,12 +332,9 @@ def test_gemm_sp_sm80():
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
......
......@@ -34,20 +34,22 @@ def matmul(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
}
)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......@@ -80,7 +82,7 @@ def run_gemm_ss(
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
program = matmul(
M,
N,
......@@ -105,7 +107,8 @@ def run_gemm_ss(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
......@@ -142,26 +145,11 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
low, high = (0, 4) if is_unsigned else (-2, 2)
else:
low, high = (0, 128) if is_unsigned else (-64, 64)
A = randint_semi_sparse(
M,
K,
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda',
transposed=trans_A)
B = torch.randint(
size=(N, K) if trans_B else (K, N),
low=low,
high=high,
dtype=map_torch_type(in_dtype),
device='cuda')
A = randint_semi_sparse(M, K, low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
B = torch.randint(size=(N, K) if trans_B else (K, N), low=low, high=high, dtype=map_torch_type(in_dtype), device="cuda")
else:
A = randn_semi_sparse(
M, K, dtype=map_torch_type(in_dtype), device='cuda', transposed=trans_A)
B = torch.randn(
(N, K) if trans_B else (K, N), device='cuda',
dtype=torch.float32).to(map_torch_type(in_dtype))
A = randn_semi_sparse(M, K, dtype=map_torch_type(in_dtype), device="cuda", transposed=trans_A)
B = torch.randn((N, K) if trans_B else (K, N), device="cuda", dtype=torch.float32).to(map_torch_type(in_dtype))
return A, B
......@@ -184,8 +172,7 @@ def test_gemm_ss():
run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64,
2)
run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# tfloat32 test
......@@ -222,10 +209,10 @@ def matmul_rs(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -233,11 +220,13 @@ def matmul_rs(
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
}
)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......@@ -271,7 +260,7 @@ def run_gemm_rs(
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
program = matmul_rs(
M,
N,
......@@ -296,7 +285,8 @@ def run_gemm_rs(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
......@@ -376,10 +366,10 @@ def matmul_sr(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -387,11 +377,13 @@ def matmul_sr(
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
}
)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......@@ -425,7 +417,7 @@ def run_gemm_sr(
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
program = matmul_sr(
M,
N,
......@@ -450,7 +442,8 @@ def run_gemm_sr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
......@@ -531,10 +524,10 @@ def matmul_rr(
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -543,12 +536,14 @@ def matmul_rr(
A_frag = T.alloc_fragment(A_frag_shape, in_dtype)
B_frag = T.alloc_fragment(B_frag_shape, in_dtype)
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, arch="8.0"),
}
)
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......@@ -583,7 +578,7 @@ def run_gemm_rr(
num_stages=3,
num_threads=128,
):
metadata_dtype = 'int32' if ('8' in in_dtype) else 'int16'
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
program = matmul_rr(
M,
N,
......@@ -608,7 +603,8 @@ def run_gemm_rr(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
A, B = generate_dense_input(M, N, K, trans_A, trans_B, in_dtype)
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K, arch="8.0")
C_sp = kernel(A_sparse, E, B)
......
......@@ -11,22 +11,14 @@ def _check(original, transformed):
mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod)
mod = tl.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)
def test_trival_pipeline():
@T.prim_func
def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(
0,
1,
annotations={
"software_pipeline_stage": [0, 1],
"software_pipeline_order": [0, 1]
}):
for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}):
with T.block():
T.reads(A[tx, i])
T.writes(C[tx, i])
......
......@@ -21,10 +21,8 @@ def _check(original, transformed):
def test_cluster_planning():
@T.prim_func
def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
(1024, 1024), "float16")):
def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16")
B_shared = T.alloc_shared((32, 128), "float16")
......@@ -41,8 +39,7 @@ def test_cluster_planning():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
(1024, 1024), "float16")):
def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
T.func_attr({"clusterIdx.y": T.int32(2)})
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16")
......
......@@ -9,7 +9,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_N = 64
num_stages = 0
threads = 128
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
batch = T.int32(batch)
heads = T.int32(heads)
......@@ -24,7 +24,6 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
......@@ -36,37 +35,36 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
T.copy(K[bz, by, k * block_N : (k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.Tensor([block_M, dim], dtype),
acc_s_cast: T.Tensor([block_M, block_N], dtype),
acc_o: T.Tensor([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
V: T.Tensor(shape, dtype),
V_shared: T.Tensor([block_M, dim], dtype),
acc_s_cast: T.Tensor([block_M, block_N], dtype),
acc_o: T.Tensor([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(V[bz, by, k * block_N : (k + 1) * block_N, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Tensor([block_M, block_N], accum_dtype),
acc_s_cast: T.Tensor([block_M, block_N], dtype),
scores_max: T.Tensor([block_M], accum_dtype),
scores_max_prev: T.Tensor([block_M], accum_dtype),
scores_scale: T.Tensor([block_M], accum_dtype),
scores_sum: T.Tensor([block_M], accum_dtype),
logsum: T.Tensor([block_M], accum_dtype),
acc_s: T.Tensor([block_M, block_N], accum_dtype),
acc_s_cast: T.Tensor([block_M, block_N], dtype),
scores_max: T.Tensor([block_M], accum_dtype),
scores_max_prev: T.Tensor([block_M], accum_dtype),
scores_scale: T.Tensor([block_M], accum_dtype),
scores_sum: T.Tensor([block_M], accum_dtype),
logsum: T.Tensor([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -92,22 +90,21 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
@T.macro
def Rescale(
acc_o: T.Tensor([block_M, dim], accum_dtype),
scores_scale: T.Tensor([block_M], accum_dtype),
acc_o: T.Tensor([block_M, dim], accum_dtype),
scores_scale: T.Tensor([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
......@@ -122,7 +119,7 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
logsum = T.alloc_fragment([block_M], accum_dtype)
block_mask = T.alloc_local([downsample_len], block_mask_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.copy(Q[bz, by, bx * block_M : (bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -131,19 +128,18 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
block_mask[vj] = BlockSparseMask[bz, by, bx, vj]
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)
)
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
T.copy(O_shared, Output[bz, by, bx * block_M : (bx + 1) * block_M, :])
return main
......
......@@ -22,7 +22,6 @@ def _check(original, transformed):
def test_lower_fence_proxy():
@T.prim_func
def before():
with T.Kernel(8):
......@@ -30,12 +29,15 @@ def test_lower_fence_proxy():
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.call_intrin(
"handle",
tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
@T.prim_func
def after():
......@@ -44,19 +46,21 @@ def test_lower_fence_proxy():
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2:i * 2 + 2] = T.Broadcast(T.float32(0), 2)
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
T.fence_proxy_async()
T.call_intrin("handle", tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.call_intrin(
"handle",
tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
_check(before, after)
def test_async_to_generic_no_double_fence():
@T.prim_func
def before():
with T.Kernel(8):
......@@ -90,7 +94,6 @@ def test_async_to_generic_no_double_fence():
def test_proxy_hint_override():
@T.prim_func
def before():
with T.Kernel(8):
......@@ -123,7 +126,6 @@ def test_proxy_hint_override():
def test_tma_store_sync_injection():
@T.prim_func
def before():
with T.Kernel(8):
......@@ -154,7 +156,6 @@ def test_tma_store_sync_injection():
def test_wgmma_marked_async():
@T.prim_func
def before():
with T.Kernel(1):
......@@ -164,9 +165,24 @@ def test_wgmma_marked_async():
C_local = T.decl_buffer((32,), "float16", scope="local")
A_shared[0] = T.float16(0)
T.warpgroup_arrive()
T.ptx_wgmma_ss("float16", "m64n64k16", T.bool(True), T.bool(True), "fp16", "fp16",
"fp16", desc_a.data, T.int32(0), desc_b.data, T.int32(0), C_local.data,
T.int32(0), T.bool(True), 1, 1)
T.ptx_wgmma_ss(
"float16",
"m64n64k16",
T.bool(True),
T.bool(True),
"fp16",
"fp16",
"fp16",
desc_a.data,
T.int32(0),
desc_b.data,
T.int32(0),
C_local.data,
T.int32(0),
T.bool(True),
1,
1,
)
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
mod = tvm.tir.transform.BindTarget(auto_target)(mod)
......
......@@ -35,26 +35,25 @@ def test_inject_set_max_nreg():
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
if v - 128 == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
0, 2, 2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
else:
# Consumer branch - should have set_max_nreg(240, 1)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
# Apply the InjectSetMaxNReg pass
func = before
......@@ -67,15 +66,18 @@ def test_inject_set_max_nreg():
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
if (
isinstance(stmt, tvm.tir.Evaluate)
and hasattr(stmt.value, "op")
and hasattr(stmt.value.op, "name")
and stmt.value.op.name == "tl.set_max_nreg"
):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
assert len(set_max_nreg_calls
) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
assert len(set_max_nreg_calls) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
print("InjectSetMaxNReg test passed!")
......@@ -116,16 +118,18 @@ def test_inject_set_max_nreg_no_set_max_nreg():
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
if (
isinstance(stmt, tvm.tir.Evaluate)
and hasattr(stmt.value, "op")
and hasattr(stmt.value.op, "name")
and stmt.value.op.name == "tl.set_max_nreg"
):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# Should have no set_max_nreg calls when no_set_max_nreg is present
assert len(
set_max_nreg_calls
) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
assert len(set_max_nreg_calls) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
print("InjectSetMaxNReg with no_set_max_nreg test passed!")
......
......@@ -8,17 +8,21 @@ import pytest
auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
(64, 64, 32, 128, 8, "float16"),
])
@pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype",
[
(64, 64, 32, 128, 8, "float16"),
],
)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")
def before():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
def main(
B: T.Tensor((K, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
......@@ -26,58 +30,62 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
t = thread_bindings
for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
for vec in T.Parallel(vec_load_b):
B_shared[i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b), t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b) < K and bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[k * block_K + i * (threads * vec_load_b // block_N) +
t // (block_N // vec_load_b), bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) + vec],
T.float16(0))
B_shared[
i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
],
T.float16(0),
)
return tvm.IRModule({'main': main})
return tvm.IRModule({"main": main})
def after():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
def main(
B: T.Tensor((K, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
t = thread_bindings
for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
if (k * block_K + i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b)) * N % vec_load_b == 0:
if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0:
for vec in T.vectorized(vec_load_b):
B_shared[i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b), t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec] = T.if_then_else(
k * block_K + i *
(threads * vec_load_b // block_N) + t //
(block_N // vec_load_b) < K and bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[k * block_K + i * (threads * vec_load_b // block_N) +
t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
B_shared[
i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
],
T.float16(0),
)
else:
for vec in T.serial(vec_load_b):
B_shared[i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b), t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec] = T.if_then_else(
k * block_K + i *
(threads * vec_load_b // block_N) + t //
(block_N // vec_load_b) < K and bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[k * block_K + i * (threads * vec_load_b // block_N) +
t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
B_shared[
i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
],
T.float16(0),
)
return tvm.IRModule({'main': main})
return tvm.IRModule({"main": main})
with tvm.target.Target(auto_target):
mod = tvm.tir.transform.BindTarget(auto_target)(before())
......
......@@ -8,7 +8,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
def main(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
......@@ -16,17 +18,18 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
A_shared[tid, j] = A[tid + M_offset, j + N_offset]
@T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),):
def expected(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
T.reads(A[tid + M_offset, N_offset:N + N_offset])
T.reads(A[tid + M_offset, N_offset : N + N_offset])
for j in T.serial(N):
A_shared[tid, j] = T.if_then_else(
j + N_offset < N,
T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
T.float32(0)), T.float32(0))
j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0)
)
return main, expected
......@@ -41,13 +44,13 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens = T.dynamic('num_tokens')
num_tokens = T.dynamic("num_tokens")
num_threads = 128
@T.prim_func
def main(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var('int')
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
......@@ -59,24 +62,22 @@ def issue_1013_buggy_kernel():
@T.prim_func
def expected(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var('int')
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += T.Cast("int32",
T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64,
N: int = 64,
M_offset: int = 2,
N_offset: int = 2):
def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
def main(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
......@@ -85,17 +86,18 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64,
T.atomic_add(A[tid + M_offset, j + N_offset], 1)
@T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),):
def expected(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
T.reads(A[tid + M_offset, N_offset:N + N_offset])
T.reads(A[tid + M_offset, N_offset : N + N_offset])
for j in T.serial(N):
A_shared[tid, j] = T.if_then_else(
j + N_offset < N,
T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
T.float32(0)), T.float32(0))
j + N_offset < N, T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset], T.float32(0)), T.float32(0)
)
# Nest if-then-else is expected, do not flatten it to pass structural equal check
if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M:
......@@ -115,17 +117,21 @@ def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: in
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
def main(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding()
for j in T.serial(N):
A[tid + M_offset, j + N_offset] = 1
@T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),):
def expected(
A: T.Tensor((M, N), dtype=dtype),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding()
T.writes(A[tid + M_offset, N_offset:N + N_offset])
T.writes(A[tid + M_offset, N_offset : N + N_offset])
for j in T.serial(N):
if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M:
......
......@@ -9,7 +9,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
vec_len = 8
@T.prim_func
def main(A: T.Tensor((M, N, vec_len), dtype="float32"),):
def main(
A: T.Tensor((M, N, vec_len), dtype="float32"),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
......@@ -18,7 +20,9 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
A_shared[tid, j, v] = A[tid, j, v]
@T.prim_func
def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),):
def expected(
A: T.Tensor((M, N, vec_len), dtype="float32"),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
......
......@@ -8,12 +8,10 @@ def _check(original, transformed):
func = original
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.LetInline()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), True)
def test_let_binding():
@T.prim_func
def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
for i in range(128):
......@@ -34,7 +32,6 @@ def test_let_binding():
def test_parallel_scope():
@T.prim_func
def before(A: T.Tensor((128,), "float32")):
for i in T.Parallel(128):
......
......@@ -24,7 +24,6 @@ def _check(original, transformed):
def test_lower_hopper_intrin_barrier():
@T.prim_func
def before():
with T.Kernel(8):
......@@ -37,18 +36,10 @@ def test_lower_hopper_intrin_barrier():
v_1 = T.launch_thread("threadIdx.x", 128)
T.evaluate(tir.Call("handle", "tir.create_barriers", [4]))
with T.If(v_1 == 0), T.Then():
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.get_mbarrier(0), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.get_mbarrier(1), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.get_mbarrier(2), 128]))
T.evaluate(
tir.Call("handle", "tir.ptx_init_barrier_thread_count",
[T.get_mbarrier(3), 128]))
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(0), 128]))
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(1), 128]))
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(2), 128]))
T.evaluate(tir.Call("handle", "tir.ptx_init_barrier_thread_count", [T.get_mbarrier(3), 128]))
T.evaluate(tir.Call("handle", "tir.tvm_storage_sync", ["shared"]))
_check(before, after)
......
......@@ -8,63 +8,69 @@ import pytest
auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize("block_M, block_N, block_K, threads, vec_load_b, dtype", [
(64, 64, 32, 128, 8, "float16"),
])
@pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype",
[
(64, 64, 32, 128, 8, "float16"),
],
)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
N = tvm.te.var("n")
K = tvm.te.var("k")
def before():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
def main(
B: T.Tensor((K, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(B[k * block_K, bx * block_N], B_shared)
return tvm.IRModule({'main': main})
return tvm.IRModule({"main": main})
def after():
@T.prim_func
def main(B: T.Tensor((K, N), dtype),):
def main(
B: T.Tensor((K, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
t = thread_bindings
for i in T.unroll(0, block_N * block_K // (threads * vec_load_b)):
if (k * block_K + i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b)) * N % vec_load_b == 0:
if (k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b)) * N % vec_load_b == 0:
for vec in T.vectorized(vec_load_b):
B_shared[i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b), t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec] = T.if_then_else(
k * block_K + i *
(threads * vec_load_b // block_N) + t //
(block_N // vec_load_b) < K and bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[k * block_K + i * (threads * vec_load_b // block_N) +
t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
B_shared[
i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
],
T.float16(0),
)
else:
for vec in T.serial(vec_load_b):
B_shared[i * (threads * vec_load_b // block_N) + t //
(block_N // vec_load_b), t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec] = T.if_then_else(
k * block_K + i *
(threads * vec_load_b // block_N) + t //
(block_N // vec_load_b) < K and bx * block_N + t %
(block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[k * block_K + i * (threads * vec_load_b // block_N) +
t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
B_shared[
i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
] = T.if_then_else(
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b) < K
and bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) < N,
B[
k * block_K + i * (threads * vec_load_b // block_N) + t // (block_N // vec_load_b),
bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec,
],
T.float16(0),
)
return tvm.IRModule({'main': main})
return tvm.IRModule({"main": main})
with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(before())
......
......@@ -80,7 +80,6 @@ def test_target_host_removed():
@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
......@@ -102,7 +101,6 @@ def test_internal_subroutine_call():
@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm", host="llvm")})
......@@ -121,7 +119,8 @@ def test_internal_subroutine_call():
subroutine_call_op = compute_scope.body.value.op
assert isinstance(subroutine_call_op, tvm.ir.GlobalVar), (
f"The main function's CallNode should use the subroutine's GLobalVar as the operation, "
f"but instead has an operation of type {subroutine_call_op}")
f"but instead has an operation of type {subroutine_call_op}"
)
def test_subroutine_call_to_externally_visible_subroutine():
......@@ -135,7 +134,6 @@ def test_subroutine_call_to_externally_visible_subroutine():
@I.ir_module
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
......@@ -154,11 +152,10 @@ def test_subroutine_call_to_externally_visible_subroutine():
assert subroutine_compute_scope is not None
subroutine_call_op = main_compute_scope.body.value.op
assert (
isinstance(subroutine_call_op, tvm.ir.Op) and
subroutine_call_op.name == "tir.tvm_call_cpacked"
), (f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
f"but instead has an operation of type {subroutine_call_op}")
assert isinstance(subroutine_call_op, tvm.ir.Op) and subroutine_call_op.name == "tir.tvm_call_cpacked", (
f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', "
f"but instead has an operation of type {subroutine_call_op}"
)
@tilelang.testing.requires_llvm
......@@ -167,10 +164,10 @@ def test_function_call_with_wrong_argument_count():
@T.prim_func
def func(
A: T.Buffer([16, 16], "int32"),
B: T.Buffer([16, 16], "int32"),
C: T.Buffer([16, 16], "int32"),
D: T.Buffer([16, 16], "int32"),
A: T.Buffer([16, 16], "int32"),
B: T.Buffer([16, 16], "int32"),
C: T.Buffer([16, 16], "int32"),
D: T.Buffer([16, 16], "int32"),
):
pass
......
......@@ -31,7 +31,6 @@ block_K = 32
def test_multi_version_buffer():
@T.prim_func
def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
......@@ -49,21 +48,27 @@ def test_multi_version_buffer():
for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
k * 32, by * 64)
k * 32,
by * 64,
)
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
bx * 64, k * 32)
bx * 64,
k * 32,
)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
@T.prim_func
def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
......@@ -82,31 +87,32 @@ def test_multi_version_buffer():
for k in T.serial(16, annotations={"num_stages": T.int32(3)}):
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
if v == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3,
2, 0), 0,
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, k * 32)
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
)
_check(before, after)
def test_multi_version_buffer_with_let():
@T.prim_func
def before(scales: T.Tensor((4,), "float32")):
with T.block("root"):
......
......@@ -19,10 +19,8 @@ def _check(original, transformed):
def test_simple_pipeline():
@T.prim_func
def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
(1024, 1024), "float32")):
def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32")
B_shared = T.alloc_shared((32, 128), "float32")
......@@ -39,8 +37,7 @@ def test_simple_pipeline():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
(1024, 1024), "float32")):
def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32")
B_shared = T.alloc_shared((32, 128), "float32")
......@@ -49,14 +46,13 @@ def test_simple_pipeline():
T.clear(C_local)
for ko in T.serial(
32,
annotations={
"software_pipeline_async_stages": [T.int32(0)],
"software_pipeline_order": [T.int32(0), T.int32(1),
T.int32(2)],
"software_pipeline_stage": [T.int32(3), T.int32(3),
T.int32(3)]
}):
32,
annotations={
"software_pipeline_async_stages": [T.int32(0)],
"software_pipeline_order": [T.int32(0), T.int32(1), T.int32(2)],
"software_pipeline_stage": [T.int32(3), T.int32(3), T.int32(3)],
},
):
T.copy(A[by * 128, ko * 32], A_shared)
T.copy(B[ko * 32, bx * 128], B_shared)
T.gemm(A_shared, B_shared, C_local)
......
......@@ -8,14 +8,13 @@ def modify(
with_B: bool = False,
with_bias: bool = False,
):
@T.prim_func
def main(
A: T.Tensor((64, 64)),
B: T.Tensor((64, 64)),
C: T.Tensor((64, 64)),
D: T.Tensor((64, 64)),
bias: T.Tensor((64, 64)),
A: T.Tensor((64, 64)),
B: T.Tensor((64, 64)),
C: T.Tensor((64, 64)),
D: T.Tensor((64, 64)),
bias: T.Tensor((64, 64)),
):
if with_B:
if with_bias:
......@@ -42,7 +41,6 @@ def test_modify(with_B=False, with_bias=False):
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
a: T.handle,
......@@ -76,6 +74,7 @@ def test_matmul():
kernel = tl.compile(mod["main"], out_idx=[2])
import torch
a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
b = torch.randn(1024, 1024, dtype=torch.float16).cuda().half()
c = kernel(a, b)
......
......@@ -11,11 +11,7 @@ def run_passes(func: tvm.tir.PrimFunc):
cuda_target = tvm.target.Target("cuda", host="llvm")
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
"global_symbol": "test",
"target": cuda_target
}))(
mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}))(mod)
mod = tvm.tir.transform.AnnotateDeviceRegions()(mod)
mod = tvm.tir.transform.SplitHostDevice()(mod)
......@@ -24,7 +20,6 @@ def run_passes(func: tvm.tir.PrimFunc):
@tilelang.testing.requires_cuda
def test_sync_if_with_same_index():
@T.prim_func(check_well_formed=False)
def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
threadIdx_x = T.env_thread("threadIdx.x")
......@@ -47,7 +42,6 @@ def test_sync_if_with_same_index():
@tilelang.testing.requires_cuda
def test_sync_read_thread_id_independent_location():
@T.prim_func
def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None:
threadIdx_x = T.env_thread("threadIdx.x")
......@@ -71,7 +65,6 @@ def test_sync_read_thread_id_independent_location():
@tilelang.testing.requires_cuda
def test_sync_shared():
@T.prim_func(private=True)
def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")):
blockIdx_x = T.launch_thread("blockIdx.x", 1)
......@@ -113,7 +106,6 @@ def test_sync_shared():
@tvm.testing.requires_cuda
def test_sync_let_stmt():
@T.prim_func(private=True)
def func(A: T.Buffer((16 * 512), "float32")):
blockIdx_x = T.launch_thread("blockIdx.x", 16)
......@@ -136,9 +128,9 @@ def test_sync_let_stmt():
in_thread_A_temp_1[0] = A_temp
cross_thread_A_temp_1 = T.Buffer((1,), data=cross_thread_A_temp, scope="local")
with T.attr(
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
"reduce_scope",
T.reinterpret("handle", T.uint64(0)),
):
T.tvm_thread_allreduce(
T.uint32(1),
......@@ -190,16 +182,19 @@ def test_sync_let_stmt():
@tilelang.testing.requires_cuda
def test_sync_shared_dyn_stmatrix_loop_hoist():
@T.prim_func
def func():
buf_dyn_shmem = T.alloc_buffer((98304,), "uint8", scope="shared.dyn")
tx = T.launch_thread("threadIdx.x", 384)
for i in T.unroll(8):
off = (
i // 4 * 8192 + tx // 32 * 1024 + tx % 16 * 64 +
(tx % 8 // 4 + i % 4 // 2) % 2 * 32 + (tx % 4 // 2 + i % 2) % 2 * 16 +
(tx % 32 // 16 + tx % 2) % 2 * 8)
i // 4 * 8192
+ tx // 32 * 1024
+ tx % 16 * 64
+ (tx % 8 // 4 + i % 4 // 2) % 2 * 32
+ (tx % 4 // 2 + i % 2) % 2 * 16
+ (tx % 32 // 16 + tx % 2) % 2 * 8
)
T.evaluate(
T.call_intrin(
"handle",
......@@ -214,7 +209,8 @@ def test_sync_shared_dyn_stmatrix_loop_hoist():
2,
),
T.int32(2),
))
)
)
mod = tvm.IRModule({"main": func})
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
......