# gemm_mfma.py — lowering of T.gemm to AMD Matrix Core (MFMA) intrinsics.

from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,
)
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang import language as T
from tilelang.transform.simplify import _Simplify


class GemmMFMA(GemmBase):
    """Lower a tile-level GEMM to AMD Matrix Core (MFMA) intrinsics.

    Handles the four operand-scope combinations of A and B living in
    shared memory ("s") or register fragments ("r") — ss, sr, rs, rr —
    while the accumulator C is always a register fragment.
    """

    def _make_mfma_emitter(self, target: Target, thread_nums: int,
                           thread_var: tir.Var = None) -> MatrixCoreIntrinEmitter:
        """Construct the MFMA intrinsic emitter shared by ``infer_layout`` and ``lower``.

        Computes the warp partition for the (M, N) tile and derives the
        per-warp tile sizes from it.  ``thread_var`` is only forwarded when
        given, so layout inference (which has no thread binding yet) builds
        the emitter exactly as before.
        """
        m_warp, n_warp = self.policy.compute_warp_partition(
            self.M, self.N, thread_nums, target, False)
        kwargs = dict(
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            a_transposed=self.trans_A,
            b_transposed=self.trans_B,
            block_row_warps=m_warp,
            block_col_warps=n_warp,
            warp_row_tiles=int(self.M // m_warp),
            warp_col_tiles=int(self.N // n_warp),
            chunk=self.chunk,
            k_pack=self.k_pack,
        )
        if thread_var is not None:
            kwargs["thread_var"] = thread_var
        return MatrixCoreIntrinEmitter(**kwargs)

    def infer_layout(self, target: Target, thread_nums: int):
        """Return the buffer -> layout mapping for A, B and C.

        Shared-memory operands get a swizzled layout; fragment operands get
        the MFMA load layout; C always gets the MFMA store layout.
        Raises ``ValueError`` for an unsupported scope combination.
        """
        mfma_emitter = self._make_mfma_emitter(target, thread_nums)

        if self.is_gemm_ss():
            a_layout = make_swizzled_layout(self.A)
            b_layout = make_swizzled_layout(self.B)
        elif self.is_gemm_sr():
            a_layout = make_swizzled_layout(self.A)
            b_layout = mfma_emitter.make_mfma_load_layout(self.B, matrix="B")
        elif self.is_gemm_rs():
            a_layout = mfma_emitter.make_mfma_load_layout(self.A, matrix="A")
            b_layout = make_swizzled_layout(self.B)
        elif self.is_gemm_rr():
            a_layout = mfma_emitter.make_mfma_load_layout(self.A, matrix="A")
            b_layout = mfma_emitter.make_mfma_load_layout(self.B, matrix="B")
        else:
            raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

        return {
            self.A: a_layout,
            self.B: b_layout,
            self.C: mfma_emitter.make_mfma_store_layout(self.C),
        }

    def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: tir.Var):
        """Lower this GEMM into a simplified TIR prim_func of MFMA intrinsics.

        Builds one of four inner macros depending on the operand scopes:
        shared operands are staged into local fragments via ``ldmatrix_a`` /
        ``ldmatrix_b``; fragment operands are consumed directly.  Raises
        ``ValueError`` for an unsupported scope combination.
        """
        mfma_emitter = self._make_mfma_emitter(target, thread_nums, thread_var)

        in_dtype = self.in_dtype
        warp_rows = mfma_emitter.warp_rows
        warp_cols = mfma_emitter.warp_cols
        local_size_a = mfma_emitter.local_size_a
        local_size_b = mfma_emitter.local_size_b
        block_K = mfma_emitter.chunk
        micro_size_k = mfma_emitter.micro_size_k

        # Use regions (not raw buffers) for the memory operands so strided
        # inputs are supported, e.g. T.gemm(A_shared[0:128, :], B_shared, C_local).
        A_region = self.ARegion
        B_region = self.BRegion
        C_region = self.CRegion

        A_buf = A_region.buffer
        B_buf = B_region.buffer
        C_buf = C_region.buffer

        clear_accum = self.clear_accum

        assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
        assert is_full_region(C_region), "Fragment output C must be a full region"

        # Number of packed-K macro iterations per block tile.
        k_iters = block_K // (micro_size_k * self.k_pack)

        if self.is_gemm_ss():

            @T.prim_func
            def _gemm_ssr() -> None:
                """Load A and B tiles from shared memory into local fragments,
                then issue Matrix Core mfma ops accumulating into C_buf."""
                A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
                B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
                if clear_accum:
                    T.clear(C_buf)
                for ki in T.serial(0, k_iters):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_region,
                        ki,
                    )
                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_region,
                        ki,
                    )
                    # Perform matrix multiplication
                    mfma_emitter.mfma(A_local, B_local, C_buf, ki)

            # Simplify to optimize the index computing.
            # Must inline let statements to simplify the analysis.
            return _Simplify(_gemm_ssr, inline_let=True)
        elif self.is_gemm_sr():
            assert is_full_region(B_region), "Fragment input B must be a full region"

            @T.prim_func
            def _gemm_srr() -> None:
                """Load A tiles from shared memory into a local fragment and
                combine them with the B fragment via Matrix Core mfma ops."""
                A_local = T.alloc_local((warp_rows * local_size_a * self.k_pack), in_dtype)
                if clear_accum:
                    T.clear(C_buf)
                for ki in T.serial(0, k_iters):
                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_region,
                        ki,
                    )
                    # Perform matrix multiplication
                    mfma_emitter.mfma(A_local, B_buf, C_buf, ki)

            # Simplify to optimize the index computing.
            # Must inline let statements to simplify the analysis.
            return _Simplify(_gemm_srr, inline_let=True)
        elif self.is_gemm_rs():
            assert is_full_region(A_region), "Fragment input A must be a full region"

            @T.prim_func
            def _gemm_rsr() -> None:
                """Load B tiles from shared memory into a local fragment and
                combine them with the A fragment via Matrix Core mfma ops."""
                B_local = T.alloc_local((warp_cols * local_size_b * self.k_pack), in_dtype)
                if clear_accum:
                    T.clear(C_buf)
                for ki in T.serial(0, k_iters):
                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_region,
                        ki,
                    )
                    # Perform matrix multiplication
                    mfma_emitter.mfma(A_buf, B_local, C_buf, ki)

            # Simplify to optimize the index computing.
            # Must inline let statements to simplify the analysis.
            return _Simplify(_gemm_rsr, inline_let=True)
        elif self.is_gemm_rr():
            assert is_full_region(A_region), "Fragment input A must be a full region"
            assert is_full_region(B_region), "Fragment input B must be a full region"

            @T.prim_func
            def _gemm_rrr() -> None:
                """Both operands already live in fragments; issue Matrix Core
                mfma ops directly, accumulating into C_buf."""
                # NOTE(review): the rr path previously ignored clear_accum,
                # unlike the ss/sr/rs paths; cleared here for consistency.
                if clear_accum:
                    T.clear(C_buf)
                for ki in T.serial(0, k_iters):
                    # Perform matrix multiplication
                    mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)

            # Simplify to optimize the index computing.
            # Must inline let statements to simplify the analysis.
            return _Simplify(_gemm_rrr, inline_let=True)
        else:
            raise ValueError(f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")

    def is_gemm_ss(self) -> bool:
        """True when both A and B live in shared memory."""
        return is_shared(self.A) and is_shared(self.B)

    def is_gemm_sr(self) -> bool:
        """True when A lives in shared memory and B is a register fragment."""
        return is_shared(self.A) and is_fragment(self.B)

    def is_gemm_rs(self) -> bool:
        """True when A is a register fragment and B lives in shared memory."""
        return is_fragment(self.A) and is_shared(self.B)

    def is_gemm_rr(self) -> bool:
        """True when both A and B are register fragments."""
        return is_fragment(self.A) and is_fragment(self.B)