import torch
import argparse
import tilelang
import tilelang.language as T
import math
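
# Grouped GEMM forward example: a single kernel multiplies each row group of A
# (shape (M_i, K)) by its own matrix B[i] of shape (K, N) and writes the results
# back to back into one (sum(M_i), N) output.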


def torch_gmm(a, b, batch_sizes, batch_offsets_tensor, trans_b=False):
    """
    Perform grouped matrix multiplication using PyTorch.

    Args:
        a (torch.Tensor): Input tensor of shape (sum(batch_sizes), K).
        b (torch.Tensor): Input tensor of shape (G, K, M), or (G, M, K) if trans_b is True.
        batch_sizes (torch.Tensor): 1D tensor containing the number of rows in each group.
        batch_offsets_tensor (torch.Tensor): 1D tensor of per-group row offsets (unused in this reference implementation).
        trans_b (bool): Whether to transpose the last two dimensions of each b[i].

    Returns:
        torch.Tensor: Output tensor of shape (sum(batch_sizes), M) after grouped matrix multiplication.
    """
    assert a.shape[0] == sum(batch_sizes), "Sum of batch_sizes must equal the first dimension of a"
    assert b.shape[0] == len(batch_sizes), "The first dimension of b must match the length of batch_sizes"

    # Initialize output tensor
    out_features = b.shape[1] if trans_b else b.shape[2]
    output = torch.empty((sum(batch_sizes), out_features), device=a.device, dtype=a.dtype)

    # Perform grouped GEMM
    start = 0
    for i, size in enumerate(batch_sizes):
        end = start + size
        part_a = a[start:end]
        part_b = b[i].transpose(0, 1) if trans_b else b[i]
        part_out = torch.mm(part_a, part_b)
        output[start:end] = part_out
        start = end

    return output


@tilelang.jit(out_idx=[2])
def grouped_gemm(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype="float16"):
    """
    Build a grouped GEMM kernel whose group sizes are fixed at compile time.

    Kernel arguments:
        A: input tensor of shape (sum(batch_sizes_list), K).
        B: weight tensor of shape (G, K, N), one matrix per group.
        C: output tensor of shape (sum(batch_sizes_list), N), allocated by the JIT wrapper (out_idx=[2]).
        batch_sizes / batch_offsets / batch_padded_offsets: int32 tensors of length G giving each
            group's row count, row offset in A/C, and offset in the block_M-padded tile space.
    """
    batch_sum = sum(batch_sizes_list)
    batch_count = len(batch_sizes_list)
    accum_dtype = "float32"
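    # Each group's row count is padded up to a multiple of block_M, so the kernel grid's
    # M dimension is the total number of padded M-tiles across all groups.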
    total_m_blocks = sum((size + block_M - 1) // block_M for size in batch_sizes_list)

    @T.prim_func
    def kernel(
        A: T.Tensor([batch_sum, K], dtype),  # type: ignore
        B: T.Tensor([batch_count, K, N], dtype),  # type: ignore
        C: T.Tensor([batch_sum, N], dtype),  # type: ignore
        batch_sizes: T.Tensor([batch_count], "int32"),  # type: ignore
        batch_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
        batch_padded_offsets: T.Tensor([batch_count], "int32"),  # type: ignore
    ):
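        # One program instance per (padded M-tile, N-tile): bx walks the padded M-tiles
        # of all groups back to back, while by tiles the shared N dimension.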
        with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared([block_M, block_K], dtype)
            B_shared = T.alloc_shared([block_K, block_N], dtype)
            C_local = T.alloc_fragment([block_M, block_N], accum_dtype)
            cur_batch_idx = T.alloc_local([1], "int32")
            cur_batch_size = T.alloc_local([1], "int32")

            m_start_padded = bx * block_M

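            # Scan the groups to find the one this padded M-tile belongs to, i.e. the
            # last group whose padded offset does not exceed m_start_padded.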
            for i in range(batch_count):
                in_cur_batch_idx = m_start_padded >= batch_padded_offsets[i]
                cur_batch_idx[0] = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx[0])

            cur_batch_size[0] = batch_sizes[cur_batch_idx[0]]
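            # Translate the padded tile start into the group's real row range in A/C, and
            # clamp the tile height so a group's last (partial) tile stays within its rows.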
            m_start = m_start_padded - batch_padded_offsets[cur_batch_idx[0]] + batch_offsets[cur_batch_idx[0]]
            actual_rows = T.max(0, T.min(block_M, cur_batch_size[0] + batch_padded_offsets[cur_batch_idx[0]] - m_start_padded))

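            # Accumulate over K in block_K chunks, staging one (block_M, block_K) slice of A
            # and one (block_K, block_N) slice of the current group's B through shared memory
            # per pipeline stage.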
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[m_start : m_start + block_M, k * block_K : (k + 1) * block_K], A_shared)
                T.copy(B[cur_batch_idx[0], k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

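            # Store only the valid rows of this tile; rows at or beyond actual_rows are
            # padding and must not be written back.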
            for i, j in T.Parallel(block_M, block_N):
                with T.If(i < actual_rows), T.Then():
                    C[m_start + i, by * block_N + j] = C_local[i, j]

    return kernel


def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
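    # Note: trans_b is accepted for signature symmetry but is unused here; B is always
    # laid out as (G, K, M).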
    batch_sum = sum(batch_sizes_list)
    batch_count = len(batch_sizes_list)
    batch_offsets_list = [0]
    batch_padded_offsets_list = [0]
    for i in range(batch_count - 1):
        batch_offsets_list.append(batch_offsets_list[-1] + batch_sizes_list[i])
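    # batch_padded_offsets[i] is group i's start in the padded M space, where each group's
    # row count is rounded up to a multiple of padding_M (which equals block_M).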
    for i in range(batch_count - 1):
        batch_padded_offsets_list.append(batch_padded_offsets_list[-1] + math.ceil((batch_sizes_list[i]) / padding_M) * padding_M)
    A = torch.randn(batch_sum, K, device=device, dtype=dtype)
    B = torch.randn(batch_count, K, M, device=device, dtype=dtype)
    C = torch.empty(batch_sum, M, device=device, dtype=dtype)
    batch_sizes = torch.tensor(batch_sizes_list, device=device, dtype=torch.int32)
    batch_offsets = torch.tensor(batch_offsets_list, device=device, dtype=torch.int32)
    batch_padded_offsets = torch.tensor(batch_padded_offsets_list, device=device, dtype=torch.int32)
    # print(batch_sizes_tensor)
    # print(batch_offsets_tensor)
    # print(batch_padded_offsets_tensor)
    return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets


def run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages=2, threads=128, profile=False):
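    # Pad each group's rows to the kernel's M-tile size so no tile straddles two groups.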
    padding_M = block_M
    batch_sum = sum(batch_sizes_list)
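    # Group sizes are baked into the kernel at compile time; note that this helper's M
    # plays the role of the kernel's output dimension N.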
    kernel = grouped_gemm(tuple(batch_sizes_list), K, M, block_M, block_N, block_K, num_stages, threads)
    # print(kernel.get_kernel_source())

    device = torch.device("cuda")
    dtype = torch.float16

    A, B, C, batch_sizes, batch_offsets, batch_padded_offsets = construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype)
    out = kernel(A, B, batch_sizes, batch_offsets, batch_padded_offsets)
    ref_output = torch_gmm(A, B, batch_sizes, batch_offsets, trans_b)
    # print(out)
    # print(ref_output)
    if torch.allclose(out, ref_output, rtol=0.01, atol=0.01):
        print("✅ Tilelang and Torch match")
    else:
        print("❌ Tilelang and Torch mismatch")

    if profile:
        profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
        latency = profiler.do_bench(warmup=500, input_tensors=[A, B, batch_sizes, batch_offsets, batch_padded_offsets])
        print(f"Latency: {latency} ms")
        print(f"TFlops: {batch_sum * K * M * 2 / latency * 1e-9} TFlops")


def test_grouped_gemm():
    run_tilelang_grouped_gemm([64], 8192, 8192, 64, 64, 64, False)
    run_tilelang_grouped_gemm([64, 128, 256], 8192, 8192, 64, 64, 64, False)
    run_tilelang_grouped_gemm([63], 8192, 8192, 64, 64, 64, False)
    run_tilelang_grouped_gemm([100, 200, 300, 400], 8192, 8192, 64, 64, 64, False)
    run_tilelang_grouped_gemm([63, 77, 111, 280], 8192, 8192, 64, 64, 64, False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_sizes", type=str, default="64, 128", help="comma-separated batch sizes")
    parser.add_argument("--K", type=int, default=8192, help="reduce dim")
    parser.add_argument("--M", type=int, default=8192, help="output dim")
    parser.add_argument("--trans_b", action="store_true", help="transpose B")
    parser.add_argument("--profile", action="store_true", help="profile")
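
    # Example invocation (illustrative values):
    #   python example_grouped_gemm_fwd.py --batch_sizes 64,128,256 --K 8192 --M 8192 --profile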
    args = parser.parse_args()

    batch_sizes_list = [int(x) for x in args.batch_sizes.split(",")]
    K, M, trans_b = args.K, args.M, args.trans_b

    block_M = 64
    block_N = 128
    block_K = 64
    num_stages = 2
    threads = 256

    run_tilelang_grouped_gemm(batch_sizes_list, K, M, block_M, block_N, block_K, trans_b, num_stages, threads, profile=args.profile)