/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <fmha/utils.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // The default input type.
    using Input_type_ = Data_type_;
    // Does it store the array of elements.
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // The number of elements.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    // The size of element in bits.
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
    // The size in bytes of a single register.
    static constexpr int BYTES_PER_REG = 4;
    // The size in bits.
    static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
    // The number of registers needed to store the fragment.
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // The size in bytes (as returned by sizeof(Fragment_base<>)).
    static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
    // The alignment.
    static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};
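
////////////////////////////////////////////////////////////////////////////////////////////////////

// A worked example of the constants above (a compile-time sketch matching the 8-element, 16-bit
// operand fragments declared below): 8 elts x 16 bits = 128 bits pack into DivUp(128, 32) = 4
// registers, i.e. 16 bytes, and the default alignment is Min(16, 16) = 16.
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::NUM_REGS == 4, "");
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::SIZE_IN_BYTES == 16, "");
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::ALIGNMENT == 16, "");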

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements.
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
    }

    // Add another fragment.
    inline __device__ void add(const Fragment &other) {
        // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
        // Also are we doing int addition or __half2 addition?
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }

    // Multiply by another fragment.
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
        }
    }

    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hrelu2(this->reg(ii));
        }
    }
};
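
////////////////////////////////////////////////////////////////////////////////////////////////////

// A minimal usage sketch (illustrative only; "other" stands for a hypothetical second fragment):
//
//     Fragment<uint16_t, 8> frag;   // 8 x 16-bit elements in 4 registers
//     frag.clear();                 // zero the backing registers
//     frag.hmul(other);             // packed __half2 multiply, register by register
//     uint16_t e0 = frag.elt(0);    // view the packed registers as elements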

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add another fragment elementwise.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    inline __device__ void mul_(const float other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) *= other;
        }
    }

    // Do the HMMA.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a,
                               const Fragment_b<Layout_b> &b) {
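        // The warp's 16x16 output tile is computed as two m16n8k16 MMAs. The first multiplies the
        // full A operand (regs 0-3) by the first 8-column half of B (regs 0-1) into elt(0..3).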
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(  elt(0)), "+f"(  elt(1)), "+f"(  elt(2)), "+f"(  elt(3))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(0)),  "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(  elt(4)), "+f"(  elt(5)), "+f"(  elt(6)), "+f"(  elt(7))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(2)),  "r"(b.reg(3)));
    }

};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            frag[mi][ni].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
  template< typename Acc, int M, int N >
  static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
    fmha::clear(acc);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {

    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            acc[mi][ni].mma(a[mi], b[ni]);
        }
    }
}
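
////////////////////////////////////////////////////////////////////////////////////////////////////

// A typical call pattern (a sketch; frag_a, frag_b and MMAS_K are hypothetical names for the
// caller's per-step operands and K-step count):
//
//     Fragment_accumulator acc[MMAS_M][MMAS_N];
//     fmha::clear(acc);
//     #pragma unroll
//     for( int ki = 0; ki < MMAS_K; ++ki ) {
//         fmha::gemm(acc, frag_a[ki], frag_b[ki]);
//     }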

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {

    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};
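
////////////////////////////////////////////////////////////////////////////////////////////////////

// For instance, a hypothetical Cta_tile_<128, 128, 64, 4, 1, 1> describes a 128x128x64 CTA tile
// computed by 4 warps stacked along M:
static_assert(Cta_tile_<128, 128, 64, 4, 1, 1>::WARPS_PER_CTA == 4, "");
static_assert(Cta_tile_<128, 128, 64, 4, 1, 1>::THREADS_PER_CTA == 128, "");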

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM.
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
        MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
        MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //     N_PER_WARP = MMAS_N * N_PER_MMA,
    //     K_PER_WARP = MMAS_K * K_PER_MMA;

};
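
////////////////////////////////////////////////////////////////////////////////////////////////////

// Continuing the hypothetical Cta_tile_<128, 128, 64, 4, 1, 1> example: one CTA-MMA covers
// 64x16x16 elements (16x16x16 per warp, 4 warps along M), so each warp issues
// MMAS_M x MMAS_N x MMAS_K = 2 x 8 x 4 MMAs to cover the tile:
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_M == 2, "");
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_N == 8, "");
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_K == 4, "");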

////////////////////////////////////////////////////////////////////////////////////////////////////

using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;
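
////////////////////////////////////////////////////////////////////////////////////////////////////

// For example, a tile with K = 48 would be padded up to K = 64, assuming Next_power_of_two rounds
// its argument up to the next power of two (as the name suggests); M, N and the warp layout are
// left unchanged.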

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha