Unverified commit 95170ab7 authored by Cunxiao Ni, committed by GitHub

[Enhancement] Fix lint to improve grouped GEMM performance with TMA (#938)

* [Example] Fix lint to improve grouped GEMM performance with TMA

* fix lint
parent b31de0ce
@@ -4,8 +4,6 @@ import tilelang
 import tilelang.language as T
 import math
-tilelang.disable_cache()
 def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
     """
@@ -57,6 +55,7 @@ def grouped_gemm(batch_sizes_list,
     batch_sum = sum(batch_sizes_list)
     batch_count = len(batch_sizes_list)
     accum_dtype = "float32"
+    total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)
     @T.prim_func
     def kernel(
@@ -68,9 +67,7 @@ def grouped_gemm(batch_sizes_list,
             batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
     ):
-        with T.Kernel(
-                T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
-                threads=threads) as (bx, by):
+        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
             A_shared = T.alloc_shared([block_M, block_K], dtype)
             B_shared = T.alloc_shared([block_K, block_N], dtype)
             C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
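The launch-grid change above replaces the old upper bound, T.ceildiv(batch_sum, block_M) + batch_count (one spare block per group to cover tile misalignment at group boundaries), with the exact per-group tile count total_m_blocks. Here is a minimal plain-Python sketch of the two sizings; the function names and example sizes are illustrative, not part of the diff:

import math

def grid_m_old(batch_sizes_list, block_M):
    # Old bound: round the total row count up to whole tiles, then add one
    # spare block per group for the worst-case tile misalignment.
    return math.ceil(sum(batch_sizes_list) / block_M) + len(batch_sizes_list)

def grid_m_new(batch_sizes_list, block_M):
    # New count: exactly one block per block_M-sized tile of each group.
    return sum((size + block_M - 1) // block_M for size in batch_sizes_list)

sizes = [100, 300, 28]          # illustrative group sizes
print(grid_m_old(sizes, 128))   # 7: two launched blocks would do no work
print(grid_m_new(sizes, 128))   # 5: matches the tiles that actually exist

Launching only the tiles that exist avoids scheduling blocks that immediately exit.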
@@ -115,8 +112,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
         batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
     for i in range(batch_count - 1):
-        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
-                                         math.ceil((batch_sizes_list[i] + 1) / padding_M) *
-                                         padding_M)
+        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
+                                         math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
     A = torch.randn(batch_sum, K, device=device, dtype=dtype)
     B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
     C = torch.empty(batch_sum, M, device=device, dtype=dtype)
...
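The construct_inputs hunk drops the stray + 1 inside math.ceil. With the + 1, a group whose size is an exact multiple of padding_M was rounded up to one extra tile, inflating every subsequent padded offset. A small sketch of both variants, assuming the list-building loop shown in the diff; padded_offsets and the example sizes are hypothetical names used only for illustration:

import math

def padded_offsets(batch_sizes_list, padding_M, plus_one):
    # Mirrors the loop in construct_inputs: each group's start offset is the
    # previous offset plus the previous group's size rounded up to a
    # multiple of padding_M.
    offsets = [0]
    for size in batch_sizes_list[:-1]:
        bump = 1 if plus_one else 0  # the off-by-one the commit removes
        offsets.append(offsets[-1] +
                       math.ceil((size + bump) / padding_M) * padding_M)
    return offsets

sizes = [128, 256, 64]
print(padded_offsets(sizes, 128, plus_one=True))   # [0, 256, 640]: one tile too far
print(padded_offsets(sizes, 128, plus_one=False))  # [0, 128, 384]: exact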