Unverified Commit 29051439 authored by Lei Wang, committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
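The hunks below all apply the same mechanical change: yapf's one-argument-per-line wrapping is collapsed into ruff format's Black-style layout, which keeps a signature or call on a single line whenever it fits the configured line length (the project's exact line-length setting is not shown in this diff). A minimal sketch of the pattern, using the tensor_null_test signature from the first hunk with a stand-in body:

# Yapf (old style): every keyword argument on its own line.
# def tensor_null_test(M,
#                      N,
#                      K,
#                      block_M,
#                      block_N,
#                      block_K,
#                      dtype="float16",
#                      accum_dtype="float",
#                      with_bias=False):
#     ...

# Ruff format (new style): the whole signature stays on one line when it fits.
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
    # Stub body for illustration only; the real function builds a TileLang kernel.
    return (M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias)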
......@@ -7,16 +7,7 @@ from tilelang.utils import map_torch_type
@tl.jit
def tensor_null_test(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float",
with_bias=False):
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -48,12 +39,10 @@ def tensor_null_test(M,
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
kernel = tensor_null_test(
M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
kernel(a, b, c, None)
......
......@@ -206,6 +206,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
......@@ -233,19 +234,9 @@ def test_gemm_jit_kernel():
)
def run_nvrtc_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_nvrtc_kernel_do_bench(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -278,23 +269,12 @@ def run_nvrtc_kernel_do_bench(M,
def test_nvrtc_kernel_do_bench():
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_nvrtc_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_nvrtc_kernel_multi_stream(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -331,23 +311,12 @@ def run_nvrtc_kernel_multi_stream(M,
def test_nvrtc_kernel_multi_stream():
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_nvrtc_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_nvrtc_dynamic_shape(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -387,21 +356,15 @@ def run_nvrtc_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_dynamic_shape():
run_nvrtc_dynamic_shape(
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def check_hopper():
......@@ -412,22 +375,7 @@ def check_hopper():
return compute_capability == (9, 0)
def convolution_im2col(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
......@@ -438,9 +386,7 @@ def convolution_im2col(N,
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -449,11 +395,13 @@ def convolution_im2col(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -467,23 +415,9 @@ def convolution_im2col(N,
return main
def run_nvrtc_im2col_tma_desc(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256):
"""Test im2col TMA descriptor functionality in NVRTC backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
num_threads)
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
......@@ -501,32 +435,20 @@ def run_nvrtc_im2col_tma_desc(N,
return C
ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close(
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_nvrtc_im2col_tma_desc():
"""Test im2col TMA descriptor with NVRTC backend."""
if not check_hopper():
import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor
run_nvrtc_im2col_tma_desc(
N=4,
C=64,
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256
)
def test_nvrtc_l2_persistent_map():
......@@ -543,7 +465,6 @@ def test_nvrtc_l2_persistent_map():
block_size=256,
dtype="float32",
):
@T.prim_func
def kernel(
A: T.Tensor((M, N), dtype),
......
......@@ -16,9 +16,9 @@ def matmul_kernel_jit(
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
in_dtype="float16",
out_dtype="float32",
accum_dtype="float32",
num_stages=2,
threads=128,
):
......
......@@ -144,6 +144,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(out_dtype)
return C
......@@ -171,19 +172,9 @@ def test_gemm_jit_kernel():
)
def run_tvm_ffi_kernel_do_bench(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_tvm_ffi_kernel_do_bench(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -216,23 +207,12 @@ def run_tvm_ffi_kernel_do_bench(M,
def test_tvm_ffi_kernel_do_bench():
run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_tvm_ffi_kernel_multi_stream(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_tvm_ffi_kernel_multi_stream(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -269,23 +249,12 @@ def run_tvm_ffi_kernel_multi_stream(M,
def test_tvm_ffi_kernel_multi_stream():
run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16",
128, 256, 32, 2)
run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
def run_tvm_ffi_dynamic_shape(M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128):
def run_tvm_ffi_dynamic_shape(
M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
program = matmul(
M,
N,
......@@ -325,21 +294,17 @@ def run_tvm_ffi_dynamic_shape(M,
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_tvm_ffi_dynamic_shape():
run_tvm_ffi_dynamic_shape(
T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128,
256, 32, 2)
run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16",
"float16", 128, 256, 32, 2)
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
)
def check_hopper():
......@@ -350,22 +315,7 @@ def check_hopper():
return compute_capability == (9, 0)
def convolution_im2col(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages,
threads,
dtype="float16",
accum_dtype="float"):
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
......@@ -376,9 +326,7 @@ def convolution_im2col(N,
kernel: T.Tensor((KH, KW, C, F), dtype),
out: T.Tensor((N, OH, OW, F), dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), dtype)
kernel_shared = T.alloc_shared((block_K, block_N), dtype)
out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -387,11 +335,13 @@ def convolution_im2col(N,
kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
T.annotate_layout({
T.annotate_layout(
{
out_shared: tilelang.layout.make_swizzled_layout(out_shared),
data_shared: tilelang.layout.make_swizzled_layout(data_shared),
kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
})
}
)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
......@@ -405,23 +355,9 @@ def convolution_im2col(N,
return main
def run_tvm_ffi_im2col_tma_desc(N,
C,
H,
W,
F,
K,
S,
D,
P,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=256):
def run_tvm_ffi_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256):
"""Test im2col TMA descriptor functionality in tvm_ffi backend."""
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages,
num_threads)
program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads)
conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
......@@ -439,32 +375,20 @@ def run_tvm_ffi_im2col_tma_desc(N,
return C
ref_c = ref_program(a, b)
tilelang.testing.torch_assert_close(
out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_tvm_ffi_im2col_tma_desc():
"""Test im2col TMA descriptor with tvm_ffi backend."""
if not check_hopper():
import pytest
pytest.skip("Test requires Hopper GPU (compute capability 9.0)")
# Small test case for im2col TMA descriptor
run_tvm_ffi_im2col_tma_desc(
N=4,
C=64,
H=32,
W=32,
F=64,
K=3,
S=1,
D=1,
P=1,
block_M=64,
block_N=128,
block_K=32,
num_stages=3,
num_threads=256)
N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256
)
def test_tvm_ffi_l2_persistent_map():
......@@ -481,7 +405,6 @@ def test_tvm_ffi_l2_persistent_map():
block_size=256,
dtype="float32",
):
@T.prim_func
def kernel(
A: T.Tensor((M, N), dtype),
......@@ -506,8 +429,12 @@ def test_tvm_ffi_l2_persistent_map():
kernel = elementwise_add_with_l2_cache(M, N)
source = kernel.get_host_source()
assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, (
"Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
)
assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, (
"Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
)
# Create test tensors
a = torch.randn(M, N, dtype=torch.float32).cuda()
......
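The two assert hunks just above also show how ruff format handles an over-long assertion: the message is wrapped in parentheses and moved onto its own line instead of being left on a single over-width line. A small self-contained sketch of the same pattern (the source string here is a placeholder, not the real generated host code):

# Placeholder standing in for kernel.get_host_source() from the hunk above.
source = "...__tvm_cuda_stream_set_access_policy_window_packed(...);..."

# Ruff format parenthesizes the long message rather than exceeding the line width.
assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, (
    "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
)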
......@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
......@@ -116,7 +117,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -124,10 +124,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -135,7 +137,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -145,7 +146,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -24,7 +24,7 @@ def elementwise_add(
start_x = bx * block_N
start_y = by * block_M
for (local_y, local_x) in T.Parallel(block_M, block_N):
for local_y, local_x in T.Parallel(block_M, block_N):
y = start_y + local_y
x = start_x + local_x
......
......@@ -12,7 +12,6 @@ def calc_diff(x, y):
def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype):
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
......@@ -44,8 +43,7 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
C = kernel(A, B)
ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)),
B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype))
ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)), B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype))
print(C)
print(ref_c)
diff = calc_diff(C, ref_c)
......
......@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
......@@ -115,7 +116,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -123,10 +123,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -134,7 +136,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -144,7 +145,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -27,8 +27,8 @@ def gemv_simt(
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
)
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
......@@ -55,8 +55,7 @@ def gemv_simt(
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
......@@ -88,8 +87,7 @@ def gemv_simt(
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
......@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......
......@@ -95,8 +95,8 @@ def run_gemm(
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 means
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......
......@@ -6,7 +6,8 @@ from tvm import DataType
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type
......@@ -116,7 +117,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -124,10 +124,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -135,7 +137,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -145,7 +146,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......
......@@ -81,7 +81,6 @@ def tl_matmul_simt(
C: T.Tensor(C_shape, out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
......@@ -97,7 +96,6 @@ def tl_matmul_simt(
T.clear(C_local)
for ko in T.serial(K // block_K):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -109,29 +107,24 @@ def tl_matmul_simt(
for ki in T.serial((block_K // micro_size_k)):
for i in T.serial(local_size_a):
for mk in T.vectorized(micro_size_k):
A_local[i, mk] = A_shared[warp_m * local_size_a + i,
ki * micro_size_k + mk]
A_local[i, mk] = A_shared[warp_m * local_size_a + i, ki * micro_size_k + mk]
for i in T.serial(local_size_b):
for mk in T.vectorized(micro_size_k):
B_local[i, mk] = B_shared[warp_n * local_size_b + i,
ki * micro_size_k + mk]
B_local[i, mk] = B_shared[warp_n * local_size_b + i, ki * micro_size_k + mk]
for i, j in T.grid(local_size_a, local_size_b):
for mk in T.serial(micro_size_k // dp4a_size):
if use_dp4a:
T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size],
C_local[i * local_size_b + j])
T.dp4a(A_local[i, mk * dp4a_size], B_local[j, mk * dp4a_size], C_local[i * local_size_b + j])
else:
for dp4a_idx in T.serial(dp4a_size):
C_local[i * local_size_b +
j] += A_local[i, mk * dp4a_size +
dp4a_idx] * B_local[j, mk * dp4a_size +
dp4a_idx]
C_local[i * local_size_b + j] += (
A_local[i, mk * dp4a_size + dp4a_idx] * B_local[j, mk * dp4a_size + dp4a_idx]
)
for i, j in T.grid(local_size_a, local_size_b):
C[by * block_M + warp_m * local_size_a + i,
bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]
C[by * block_M + warp_m * local_size_a + i, bx * block_N + warp_n * local_size_b + j] = C_local[i * local_size_b + j]
return main
......
......@@ -5,7 +5,6 @@ import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......@@ -59,7 +58,8 @@ def run_gemm_with_stride_ss(M: int, N: int, K: int, block_M: int, block_N: int,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
......
......@@ -27,8 +27,8 @@ def gemv_simt(
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
)
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
......@@ -55,8 +55,7 @@ def gemv_simt(
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
with T.Kernel(T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
......@@ -88,8 +87,7 @@ def gemv_simt(
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
......@@ -104,11 +102,11 @@ def gemv_simt(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
C[by, bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......
......@@ -4,7 +4,8 @@ from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
make_mma_swizzle_layout as make_swizzle_layout,
)
from tilelang.intrinsics.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,
......@@ -96,7 +97,6 @@ def tl_matmul(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -104,10 +104,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -115,7 +117,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -125,7 +126,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -168,7 +168,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
},
)
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
......@@ -290,7 +291,6 @@ def tl_matmul_weight_only_transform(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -298,10 +298,12 @@ def tl_matmul_weight_only_transform(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
......@@ -309,19 +311,15 @@ def tl_matmul_weight_only_transform(
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k,
micro_size_y, micro_size_k):
B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j,
ko * (block_K // micro_size_k) + k, jj, kk]
for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k):
B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j, ko * (block_K // micro_size_k) + k, jj, kk]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -359,6 +357,7 @@ def tl_matmul_weight_only_transform(
def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
import bitblas
matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype)
kernel = tilelang.compile(matmul, out_idx=[2])
profiler = kernel.get_profiler()
......
......@@ -6,16 +6,17 @@ import gc
def test_tilelang_capture():
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},)
},
)
def get_dummy_kernel():
@T.prim_func
def dummy_kernel(a: T.Tensor[(1,), T.float32],):
def dummy_kernel(
a: T.Tensor[(1,), T.float32],
):
with T.Kernel(1) as _:
a[0] = 1
......@@ -36,5 +37,5 @@ def test_tilelang_capture():
# objgraph.show_backrefs([a_upgrade], max_depth=5)
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
......@@ -4,25 +4,25 @@ import tilelang.language as T
def test_tilelang_intimm():
T.int32(0x7fffffff)
T.int32(-0x7fffffff - 1)
T.uint32(0xffffffff)
T.int64(0x7fffffffffffffff)
T.int64(-0x7fffffffffffffff - 1)
T.uint64(0xffffffffffffffff)
T.int32(0x7FFFFFFF)
T.int32(-0x7FFFFFFF - 1)
T.uint32(0xFFFFFFFF)
T.int64(0x7FFFFFFFFFFFFFFF)
T.int64(-0x7FFFFFFFFFFFFFFF - 1)
T.uint64(0xFFFFFFFFFFFFFFFF)
a = T.int32()
a & 0x7fffffff
a & 0x7FFFFFFF
a = T.uint32()
a & 0xffffffff
a & 0xFFFFFFFF
a = T.int64()
a & 0x7fffffffffffffff
a & 0x7FFFFFFFFFFFFFFF
a = T.uint64()
a & T.uint64(0xffffffffffffffff)
a & T.uint64(0xFFFFFFFFFFFFFFFF)
if __name__ == '__main__':
if __name__ == "__main__":
tilelang.testing.main()
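The two hunks above capture two other purely mechanical ruff format rules seen throughout this commit: string literals are normalized to double quotes (hence if __name__ == "__main__":), and the digits of hexadecimal integer literals are uppercased while the 0x prefix stays lowercase. A tiny self-contained check of the hex rule:

# ruff format rewrites 0x7fffffff as 0x7FFFFFFF; the value is unchanged.
INT32_MAX = 0x7FFFFFFF
UINT64_MAX = 0xFFFFFFFFFFFFFFFF
assert INT32_MAX == 2**31 - 1
assert UINT64_MAX == 2**64 - 1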
......@@ -5,7 +5,6 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
......@@ -13,11 +13,10 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(K // block_K):
if torch.all(BlockMask[i, j, k]):
accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ B[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32)
ref_c[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = (
accu.to(torch.float16))
accu += A[i * block_M : (i + 1) * block_M, k * block_K : (k + 1) * block_K].to(torch.float32) @ B[
k * block_K : (k + 1) * block_K, j * block_N : (j + 1) * block_N
].to(torch.float32)
ref_c[i * block_M : (i + 1) * block_M, j * block_N : (j + 1) * block_N] = accu.to(torch.float16)
return ref_c
......@@ -35,7 +34,6 @@ def blocksparse_matmul_global(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
......@@ -80,7 +78,6 @@ def blocksparse_matmul_shared(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
......@@ -130,7 +127,6 @@ def blocksparse_matmul_local(
dtype="float16",
accum_dtype="float",
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@T.prim_func
......@@ -237,7 +233,8 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......@@ -284,7 +281,8 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......
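How the reformatting is wired into the project's tooling (pre-commit, CI, or a lint script) is not part of this diff; the following is only a hedged sketch of reproducing it locally with the standalone ruff formatter, assuming ruff is installed and the formatter settings live in the project's pyproject.toml:

# Hypothetical helper, not part of this commit: run `ruff format` over the tree.
import subprocess


def format_repo(check_only: bool = False) -> int:
    cmd = ["ruff", "format", "."]
    if check_only:
        # Report files that would be reformatted without rewriting them.
        cmd.append("--check")
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    raise SystemExit(format_repo(check_only=True))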