/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <fmha/utils.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time description of a register-backed fragment: element type and count,
// plus the derived register count, byte size, and alignment used by Fragment below.
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // default input type
    using Input_type_ = Data_type_;
    // Does it store the array of elements (false for sub-byte element types).
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // The number of elements.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    // The size of element in bits.
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
    // The size of byte of a single register.
    static constexpr int BYTES_PER_REG = 4;
    // The size in bits.
    static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
    // The number of registers needed to store the fragment (rounded up).
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // The size in bytes (as returned by sizeof(Fragment_base<>).
    static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
    // The alignment: caller-forced when ALIGNMENT_ > 0, otherwise the natural size capped at 16B.
    static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// A fragment: a small per-thread array of elements packed into 32-bit registers,
// with both register-level (reg) and element-level (elt) access. This is the
// building block for the MMA operand/accumulator types below.
template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    // The backing store: NUM_REGS 32-bit registers holding the packed elements.
    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements (reinterprets the register storage as Data_type_).
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
    }

    // Add another fragment, element-wise in Data_type_ arithmetic.
    inline __device__ void add(const Fragment &other) {
        // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
        // Also are we doing int addition or __half2 addition?
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }

    // Multiply by another fragment, register-wise via fmha::hmul2
    // (presumably packed-half2 multiply; see fmha/utils.h).
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
        }
    }

    // In-place ReLU applied register-wise via fmha::hrelu2.
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hrelu2(this->reg(ii));
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// A-operand fragment: 8 fp16 values stored as uint16_t (4 registers).
// Layout is a tag type only; the storage does not depend on it.
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// B-operand fragment: 8 fp16 values stored as uint16_t (4 registers).
// Layout is a tag type only; the storage does not depend on it.
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Accumulator fragment: 8 fp32 elements (two MMA tiles of 4) in 8 registers.
struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments, element-wise in fp32. Accepts any fragment type that
    // exposes a compatible elt() accessor.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        // Unrolled for consistency with the other fixed-trip fragment loops.
        #pragma unroll
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    // Scale every element by a scalar.
    inline __device__ void mul_(const float other) {
        #pragma unroll
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) *= other;
        }
    }

    // Do the HMMA: acc += a * b using two m16n8k16 fp16->fp32 tensor-core MMAs
    // (the two halves cover the 16-wide N tile; 8 elements total).
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a,
                               const Fragment_b<Layout_b> &b) {
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(  elt(0)), "+f"(  elt(1)), "+f"(  elt(2)), "+f"(  elt(3))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(0)),  "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                    : "+f"(  elt(4)), "+f"(  elt(5)), "+f"(  elt(6)), "+f"(  elt(7))
                    :  "r"(a.reg(0)),  "r"(a.reg(1)),  "r"(a.reg(2)),  "r"(a.reg(3))
                    ,  "r"(b.reg(2)),  "r"(b.reg(3)));
    }

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Zero every fragment of an M x N tile of fragments.
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            frag[row][col].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Clears a tile of accumulators; specialized per accumulator type below.
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Float specialization: zeroing is the same for any WARPS_K, so just clear the tile.
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    // Zero out every accumulator fragment in the M x N tile.
    // The trailing bool parameter is accepted and ignored.
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-tile GEMM: every (a[row], b[col]) operand pair feeds one accumulator
// fragment, i.e. an outer-product accumulation over the M x N tile.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    #pragma unroll
    for( int row = 0; row < M; ++row ) {
        #pragma unroll
        for( int col = 0; col < N; ++col ) {
            acc[row][col].mma(a[row], b[col]);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-level GEMM implemented via CUTLASS's warp MMA, as an alternative to
// fmha::gemm. Accumulates acc += a * b over a (16*M) x (16*N) x 16 warp tile:
// the fmha fragments are repacked into per-k-iteration CUTLASS fragments, the
// warp MMA runs once per iteration, and the result is copied back into acc.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    // TD [2022-06-02] We don't support Volta (SM70) yet.
    assert(0);
#endif

    using Element = cutlass::half_t;
    using ElementC = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;

    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    // Number of instruction-level k-iterations needed to cover the 16-wide k tile.
    constexpr int kIters = Shape::kK / InstructionShape::kK;
    // Per-instruction fragments (not the whole-warp-tile FragmentA/B of WarpMma).
    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;

    // Sanity-check that the fmha fragment register counts line up with CUTLASS's.
    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);

    // The accumulator layouts match, so copy-initialize the CUTLASS accumulator
    // straight from acc's registers. NOTE: c_cl is a COPY, not an alias of acc.
    FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);

    // Repack the A/B registers into per-iteration fragments. b_cl is deliberately
    // typed FragmentA so each entry provides kRegs registers of raw storage; it is
    // reinterpret_cast to WarpMma::FragmentB at the call site below.
    FragmentA a_cl[kIters][M];
    FragmentA b_cl[kIters][N];
    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int mi = 0; mi < M; mi++) {
            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
            }
        }
    }
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                // TD [2022-06-02] For some reason the register order for frag_b is
                // different between the k16 and k8 instructions.
                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
            }
        }
    }

    WarpMma mma_op;
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
    }

    // c_cl was a copy of acc, so write the accumulated result back explicitly.
    // Each accumulator fragment holds 8 fp32 elements, laid out row-major over (mi, ni).
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
            }
        }
    }

}

////////////////////////////////////////////////////////////////////////////////////////////////////

// CTA-level GEMM tile: the M/N/K extents of the tile and how warps are laid out
// over it. All quantities are compile-time constants.
template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {

    // The tile extents.
    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Derived HMMA tiling for a given CTA tile: how many 16x16x16 warp-MMAs are
// needed to cover the CTA tile in each dimension.
template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM (rounded up, so partial
    // tiles at the edges are still covered).
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
        MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
        MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //     N_PER_WARP = MMAS_N * N_PER_MMA,
    //     K_PER_WARP = MMAS_K * K_PER_MMA;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Element types for the GEMM: fp16 operands/output (stored as uint16_t),
// fp32 accumulation and epilogue.
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

// Per-element sizes in bits (used for memory sizing elsewhere — TODO confirm callers).
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convenience alias: a CTA tile with explicit extents and warp layout.
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

// The same CTA tile with K padded up to the next power of two (other extents
// and the warp layout are unchanged).
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha