Commit de028927 authored by Yu Cheng, committed by LeiWang1999

[Dev] Add grouped GEMM backward example scripts (#515)

* Introduced `example_grouped_gemm_fwd.py` and `example_grouped_gemm_bwd.py` to demonstrate grouped matrix multiplication with forward and backward operations.
* Implemented functions for grouped GEMM, input construction, and validation against PyTorch's implementation.
* Added command-line argument parsing for flexible input configuration, including batch sizes and matrix dimensions (an example invocation is shown below).
* Included a test function to validate the functionality with various input scenarios.
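
For reference, a typical invocation of the backward example with the defaults it ships with (flag names come from the script's argument parser; running from the directory containing the script is assumed):

    python example_grouped_gemm_bwd.py --batch_sizes "64,128" --K 8192 --M 8192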
parent fb801940
import torch
import math
import argparse
import tilelang
import tilelang.language as T
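
# Disable tilelang's compiled-kernel cache so every run of the example recompiles
# the kernels (assumed from the API name; drop this call to reuse cached builds).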
tilelang.disable_cache()


def grouped_gemm_fwd(batch_sum,
                     batch_count,
                     K,
                     N,
                     block_M,
                     block_N,
                     block_K,
                     num_stages=2,
                     threads=128,
                     dtype="float16"):
    """Build the forward grouped-GEMM kernel (C = A @ B per group).

    Args:
        A (torch.Tensor): concatenated inputs of shape (batch_sum, K).
        B (torch.Tensor): per-group weights of shape (batch_count, K, N).
        C (torch.Tensor): concatenated outputs of shape (batch_sum, N).
        batch_sizes, batch_offsets, batch_padded_offsets (torch.Tensor):
            int32 tensors of shape (batch_count,) giving each group's row
            count, its row offset into A/C, and its block_M-padded offset
            used to map launch blocks onto groups.
    """
    accum_dtype = "float32"

    @T.prim_func
    def kernel(
            A: T.Tensor([batch_sum, K], dtype),  # type: ignore
            B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
            C: T.Tensor([batch_sum, N], dtype),  # type: ignore
            batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
            batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
            batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
        with T.Kernel(
                T.ceildiv(batch_sum, block_M) + batch_count, T.ceildiv(N, block_N),
                threads=threads) as (bx, by):
            A_shared = T.alloc_shared([block_M, block_K], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
            cur_batch_idx = T.alloc_local([1], "int32")
            cur_batch_size = T.alloc_local([1], "int32")

            # Find the group that owns this m-block via the padded offsets.
            m_start_padded = bx * block_M
            for i in range(batch_count):
                in_cur_batch_idx = (m_start_padded >= batch_padded_offsets[i])
                cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])
            cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]

            # Map the padded block start back to a real row of A/C and clamp to
            # the rows that actually belong to this group.
            m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[
                cur_batch_idx[0]]
            actual_rows = T.max(
                0,
                T.min(block_M,
                      cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[m_start:m_start + block_M, k * block_K:(k + 1) * block_K], A_shared)
                T.copy(
                    B[cur_batch_idx[0], k * block_K:(k + 1) * block_K,
                      by * block_N:(by + 1) * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Write back only the valid rows of the tile.
            for i, j in T.Parallel(block_M, block_N):
                with T.If(i < actual_rows), T.Then():
                    C[m_start + i, by * block_N + j] = C_local[i, j]

    return kernel
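

# Worked example of the padding scheme above, using the script's default
# batch_sizes = [64, 128] and the block_M = padding_M = 64 chosen in
# _GroupedGEMM.forward (these particular numbers are just an illustration):
#   batch_offsets        = [0, 64]    -> row start of each group in A / C
#   batch_padded_offsets = [0, 128]   -> 0, then ceil((64 + 1) / 64) * 64 = 128
# The launch grid allocates ceildiv(192, 64) + 2 = 5 m-blocks, so every block
# lies entirely inside one group and partially filled tiles are masked via
# `actual_rows`.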


class _GroupedGEMM(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b, batch_sizes):
        block_M = 64
        block_N = 64
        block_K = 64
        padding_M = block_M
        num_stages = 2
        threads = 128

        batch_sum = a.shape[0]
        batch_count = b.shape[0]
        K = a.shape[1]
        N = b.shape[2]
        assert a.shape[1] == b.shape[1]
        assert batch_sizes.shape[0] == batch_count
        assert batch_sizes.sum() == batch_sum

        batch_offsets_list = [0]
        batch_padded_offsets_list = [0]
        for i in range(batch_count - 1):
            batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes[i])
        for i in range(batch_count - 1):
            batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
                                             math.ceil((batch_sizes[i] + 1) / padding_M) *
                                             padding_M)
        batch_offsets = torch.tensor(batch_offsets_list, device=a.device, dtype=torch.int32)
        batch_padded_offsets = torch.tensor(
            batch_padded_offsets_list, device=a.device, dtype=torch.int32)

        program = grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K,
                                   num_stages, threads)
        kernel = tilelang.compile(
            program,
            out_idx=[2],
            pass_configs={
                "tl.disable_tma_lower": True,
                "tl.disable_warp_specialized": True
            })
        o = kernel(a, b, batch_sizes, batch_offsets, batch_padded_offsets)

        ctx.save_for_backward(a, b, batch_sizes, batch_offsets)
        ctx.batch_sum = batch_sum
        ctx.batch_count = batch_count
        ctx.K = K
        return o

    @staticmethod
    def backward(ctx, grad_output):
        block_M = 64
        block_N = 64
        block_K = 64
        num_stages = 2
        threads = 128

        M = ctx.K
        N = grad_output.shape[1]
        A, B, batch_sizes, batch_offsets = ctx.saved_tensors

        def maybe_contiguous(x):
            if x.stride(-1) != 1:
                return x.contiguous()
            return x

        A, B, batch_sizes = [maybe_contiguous(x) for x in (A, B, batch_sizes)]

        program = grouped_gemm_bwd(ctx.batch_sum, ctx.batch_count, M, N, block_M, block_N, block_K,
                                   num_stages, threads)
        kernel = tilelang.compile(
            program,
            out_idx=[2],
            pass_configs={
                "tl.disable_tma_lower": True,
                "tl.disable_warp_specialized": True
            })
        dB = kernel(A, grad_output, batch_sizes, batch_offsets)
        return None, dB, None


def ref_program(a, b, batch_sizes):
    """Reference grouped GEMM: one torch.mm per group."""
    assert a.shape[0] == sum(batch_sizes)
    assert b.shape[0] == len(batch_sizes)
    output = torch.empty((sum(batch_sizes), b.shape[2]), device=a.device, dtype=a.dtype)
    start = 0
    for i, size in enumerate(batch_sizes):
        end = start + size
        part_a = a[start:end]
        part_b = b[i]
        output[start:end] = torch.mm(part_a, part_b)
        start = end
    return output


def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
    batch_sum = sum(batch_sizes_list)
    batch_count = len(batch_sizes_list)
    batch_offsets_list = [0]
    batch_padded_offsets_list = [0]
    for i in range(batch_count - 1):
        batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
    for i in range(batch_count - 1):
        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] +
                                         math.ceil((batch_sizes_list[i] + 1) / padding_M) *
                                         padding_M)
    A = torch.randn(batch_sum, K, device=device, dtype=dtype)
    B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
    C = torch.empty(batch_sum, M, device=device, dtype=dtype)
    batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32)
    batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32)
    batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32)
    return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets


def grouped_gemm_bwd(batch_sum,
                     batch_count,
                     M,
                     N,
                     block_M,
                     block_N,
                     block_K,
                     num_stages=2,
                     threads=128,
                     dtype="float16"):
    """Build the backward kernel that computes the per-group weight gradient.

    Args:
        A (torch.Tensor): concatenated forward inputs of shape (batch_sum, M).
        B (torch.Tensor): upstream gradient of shape (batch_sum, N).
        C (torch.Tensor): weight gradient of shape (batch_count, M, N).
        batch_sizes, batch_offsets (torch.Tensor): int32 tensors of shape
            (batch_count,) with each group's row count and row offset.
    """
    accum_dtype = "float32"

    @T.prim_func
    def kernel(
            A: T.Tensor([batch_sum, M], dtype),  # type: ignore
            B: T.Tensor([batch_sum, N], dtype),  # type: ignore
            C: T.Tensor([batch_count, M, N], dtype),  # type: ignore
            batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
            batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
        with T.Kernel(
                T.ceildiv(M, block_M), T.ceildiv(N, block_N), batch_count,
                threads=threads) as (bx, by, bz):
            A_shared = T.alloc_shared([block_K, block_M], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)

            T.clear(C_local)
            # Reduce over the rows of group bz, block_K rows at a time; rows past
            # the end of the group are zero-filled so the tail block is masked.
            for k in T.Pipelined(T.ceildiv(batch_sizes[bz], block_K), num_stages=num_stages):
                for i, j in T.Parallel(block_K, block_M):
                    A_shared[i, j] = T.if_then_else(
                        k * block_K + i < batch_sizes[bz],
                        A[batch_offsets[bz] + k * block_K + i, bx * block_M + j], 0)
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = T.if_then_else(
                        k * block_K + i < batch_sizes[bz],
                        B[batch_offsets[bz] + k * block_K + i, by * block_N + j], 0)
                T.gemm(A_shared, B_shared, C_local, transpose_A=True)

            T.copy(C_local, C[bz, bx * block_M, by * block_N])

    return kernel
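

# Reference note on the gradient the kernel above computes (standard grouped-GEMM
# calculus, stated here only as a comment): with r_g the rows of group g and dO
# the upstream gradient passed in as B,
#   dB[g] = A[r_g, :]^T @ dO[r_g, :]
# accumulated one (block_M, block_N) tile per (bx, by) and one group per bz, with
# the `k * block_K + i < batch_sizes[bz]` guards masking the reduction tail.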


def run_tilelang_grouped_gemm(batch_sizes_list,
                              K,
                              M,
                              block_M,
                              block_N,
                              block_K,
                              trans_b,
                              num_stages=2,
                              threads=128,
                              profile=False):
    padding_M = block_M
    device = torch.device("cuda")
    dtype = torch.float16

    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(
        batch_sizes_list, K, M, False, padding_M, device, dtype)
    A.requires_grad_(False)
    B.requires_grad_(True)

    # Reference forward/backward through PyTorch autograd.
    O_ref = ref_program(A, B, batch_sizes)
    dO = torch.randn_like(O_ref)
    O_ref.backward(dO, retain_graph=True)
    dB_ref, B.grad = B.grad.clone(), None

    # Tilelang forward/backward through the custom autograd.Function.
    GroupedGEMM = _GroupedGEMM.apply
    O = GroupedGEMM(A, B, batch_sizes)
    O.backward(dO, retain_graph=True)
    dB, B.grad = B.grad.clone(), None

    if (torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) and
            torch.allclose(dB, dB_ref, rtol=1e-2, atol=1e-2)):
        print("✅ Tilelang and Torch match")
    else:
        print("❌ Tilelang and Torch mismatch")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch_sizes', type=str, default="64, 128", help='comma-separated batch sizes')
    parser.add_argument('--K', type=int, default=8192, help='reduce dim')
    parser.add_argument('--M', type=int, default=8192, help='output dim')
    parser.add_argument('--trans_b', action="store_true", help="transpose B")
    parser.add_argument('--profile', action="store_true", help="profile")
    args = parser.parse_args()

    batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
    K, M, trans_b = args.K, args.M, args.trans_b
    block_M = 64
    block_N = 128
    block_K = 64
    num_stages = 2
    threads = 256

    run_tilelang_grouped_gemm(
        batch_sizes_list,
        K,
        M,
        block_M,
        block_N,
        block_K,
        trans_b,
        num_stages,
        threads,
        profile=args.profile)