/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cstdint>
#include <type_traits>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <fmha/utils.h>

namespace fmha {

// template <typename half2_t>
// inline __device__ void atomic_add_CAS(half2_t *address, const half2_t val) {
//     uint32_t *address_as_ui = (uint32_t *)address;
//     uint32_t old = *address_as_ui;
//     uint32_t assumed;
//     do {
//       assumed = old;
//       half2_t sum = __hadd2(val, reinterpret_cast<half2_t(&)>(old));
//       old = atomicCAS(address_as_ui, assumed, reinterpret_cast<uint32_t(&)>(sum));
//     } while (assumed != old);
// }

////////////////////////////////////////////////////////////////////////////////////////////////////

// Global-memory tile for rows of Q, K or V.
//
// Each thread owns one BYTES_PER_LDG-wide slot of a row (THREADS_PER_ROW threads cover a row) and
// walks the tile in steps of ROWS_PER_LDG rows, issuing LDGS accesses per tile. Every access is
// predicated on the row index being below min(ROWS, actual_seqlen).
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile_,
    // The number of bits per element.
    int BITS_PER_ELEMENT,
    // The number of rows of Q, K or V loaded by this tile.
    int ROWS_,
    // The number of columns.
    int COLS,
    // The number of bytes moved by each thread per access.
    int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {

    using Cta_tile = Cta_tile_;

    // The size of an element in bytes.
    static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
    // The size of each LDG.
    // NOTE(review): fetch_ and the store/atomic paths are typed as uint4 (16 bytes) regardless of
    // this value -- presumably only the default of 16 is used; confirm before instantiating others.
    static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
    // The size of a row in bytes.
    static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;

    // The number of threads to load a "row" of the matrix.
    static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;

    static constexpr int ROWS = ROWS_;
    // The number of "rows" loaded per LDG.
    static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // The number of LDGs needed to load a chunk of the Q matrix.
    static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);

    // Ctor.
    //
    // ptr_                 base pointer of the matrix in global memory.
    // row_stride_in_elts   distance between two consecutive rows, in elements.
    // head_stride_in_elts  distance between two consecutive heads, in elements.
    // binfo                block info; sum_s_q / sum_s_k, actual_seqlen_q / actual_seqlen_k and bidh
    //                      are read from it.
    // tidx                 index of the thread inside the CTA.
    // use_seqlen_q         selects the *_q fields of binfo when true, the *_k fields otherwise.
    template< typename BInfo >
    inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx, bool use_seqlen_q)
        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
        , ptr(reinterpret_cast<char *>(ptr_))
        , tidx_(tidx)
        , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k) {

        // Compute the position in the sequence (within the CTA for the moment).
        const int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        const int col = tidx % THREADS_PER_ROW;

        // TD [2022-04-16]: To minimize registers, the row is recomputed from tidx_ in every method
        // instead of being stored.

        // The base offset is built in 64-bit arithmetic: (sum_s + row) * row_stride_in_bytes exceeds
        // 32 bits as soon as the tensor is larger than 4 GiB, and a wrapped offset would silently
        // address the wrong rows. It is computed only once, so the wider math does not cost
        // registers in the main loop; the per-access offsets below stay 32-bit.
        const int64_t first_row = static_cast<int64_t>(use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row;
        int64_t offset_in_bytes = first_row * row_stride_in_bytes;
        // Add the head offset.
        offset_in_bytes += static_cast<int64_t>(binfo.bidh) * head_stride_in_elts * BYTES_PER_ELEMENT;

        // Assemble the final pointer.
        ptr += offset_in_bytes + col * BYTES_PER_LDG;
    }

    // Store the data fetched by the last load() to shared memory.
    template< typename Smem_tile >
    inline __device__ void commit(Smem_tile &smem_tile) {
        smem_tile.store(fetch_);
    }

    // Load the tile from global memory into fetch_. Entries whose row is at or beyond
    // min(ROWS, actual_seqlen) are predicated off and keep the zero written here.
    inline __device__ void load() {
        const int row_ = tidx_ / THREADS_PER_ROW;
        const void *ptrs[LDGS];
        uint32_t preds[LDGS];
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            // The offset inside a tile is small, so 32-bit arithmetic is kept here.
            ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
            fetch_[ii] = make_uint4(0, 0, 0, 0);
        }

        // not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            fct.load(ii, preds[ii]);
        }
    }

    // Store data to global memory. Rows at or beyond min(ROWS, actual_seqlen) are skipped.
    inline __device__ void store(const uint4 (&data)[LDGS]) {
        const int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            char *dst = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                fmha::stg(dst, data[ii]);
            }
        }
    }

    // Atomically accumulate data into global memory. Each uint4 is reinterpreted as four packed
    // pairs (__half2 when elem_type is __half, __nv_bfloat162 otherwise) and added with atomicAdd.
    // Rows at or beyond min(ROWS, actual_seqlen) are skipped.
    template <typename elem_type>
    inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
        using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
        const int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            elem2_type *dst = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                #pragma unroll
                for (int jj = 0; jj < 4; ++jj) {
                    atomicAdd(dst + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
                }
            }
        }
    }

    // Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
    // Each 32-bit word of data is unpacked as two halves and accumulated into two consecutive
    // floats in global memory.
    inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
        static_assert(BYTES_PER_ELEMENT == 4);  // Only support fp32
        const int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            float *dst = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                #pragma unroll
                for (int jj = 0; jj < 4; ++jj) {
                    const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
                    atomicAdd(dst + jj * 2, data_f.x);
                    atomicAdd(dst + jj * 2 + 1, data_f.y);
                }
            }
        }
    }

    // Advance the tile by steps * ROWS rows and shrink the remaining sequence length accordingly.
    // steps is expected to be non-negative: the byte offset is formed in unsigned 32-bit
    // arithmetic, so a negative value would wrap instead of moving the pointer backwards.
    inline __device__ void move(const int steps = 1) {
        ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
        actual_seqlen -= ROWS * steps;
    }

    // The stride between rows for the QKV matrice.
    const uint32_t row_stride_in_bytes;
    // The pointer.
    char *ptr;
    // The fetch registers.
    uint4 fetch_[LDGS];
    // The thread index; the row processed by the thread is recomputed from it on demand.
    const int tidx_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Global-memory tile for the output matrix O.
//
// Each thread moves 4 elements per access (BYTES_PER_STG bytes). A tile of ROWS rows is written in
// LOOPS chunks of ROWS_PER_LOOP rows; store()/load() handle one chunk, selected by mi.
template<
    typename Cta_tile,
    // The size of each element in bytes (2 or 4).
    int BYTES_PER_ELEMENT = 2
>
struct Gmem_tile_o {

    static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The size of each STG.
    static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
    static constexpr int COLS = Cta_tile::N;
    // The size of a row in bytes.
    static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;

    // The number of threads to store a "row" of the matrix.
    static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
    // The number of rows in the tile.
    static constexpr int ROWS = Cta_tile::M;
    // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
    static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
    // The number of outter loop for the stores.
    static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;

    // The number of "rows" stored per STG.
    static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // Do we have to guard against partial writes/reads.
    static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
    // The number of STGs needed to store a chunk of the Q matrix.
    static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
    // The number of STGs needed to store a chunk of the Q matrix in total.
    static constexpr int STGS = STGS_PER_LOOP * LOOPS;

    // Ctor.
    //
    // ptr                  base pointer of the output matrix in global memory.
    // row_stride_in_elts   distance between two consecutive rows, in elements.
    // head_stride_in_elts  distance between two consecutive heads, in elements.
    // binfo                block info; sum_s_q, actual_seqlen_q and bidh are read from it.
    // tidx                 index of the thread inside the CTA.
    template<typename BInfo>
    inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
                                  const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
        , ptr_(reinterpret_cast<char *>(ptr))
        , actual_seqlen_q(binfo.actual_seqlen_q)
        , tidx_(tidx) {

        // Compute the position in the sequence (within the CTA for the moment).
        const int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        const int col = tidx % THREADS_PER_ROW;

        // The base offset is built in 64-bit arithmetic: (sum_s_q + row) * row_stride_in_bytes
        // exceeds 32 bits as soon as the tensor is larger than 4 GiB, and a wrapped offset would
        // silently address the wrong rows. It is computed only once; the per-access offsets in
        // store()/load() stay 32-bit.
        int64_t offset_in_bytes = (static_cast<int64_t>(binfo.sum_s_q) + row) * row_stride_in_bytes;
        offset_in_bytes += static_cast<int64_t>(binfo.bidh) * head_stride_in_elts * BYTES_PER_ELEMENT;

        // Assemble the final pointer.
        ptr_ += offset_in_bytes + col * BYTES_PER_STG;

        // Is that thread active on the last STG? Always initialized (to "active" when every STG is
        // complete) so the member never holds an indeterminate value.
        is_active_for_last_stg_ = !HAS_INCOMPLETE_STG || (row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M);
    }

    // Store chunk mi of the tile to global memory.
    //
    // With 4-byte elements src[ii] is written as is. With 2-byte elements the four 32-bit words of
    // src[ii] are read as floats and packed to four elem_type values before being written.
    // The loop stops at the first row at or beyond actual_seqlen_q: row indices only grow with ii,
    // so all following rows are out of range as well.
    template<typename elem_type=__half>
    inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
        const int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            const int jj = mi * STGS_PER_LOOP + ii;
            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                break;
            }

            // Threads that fall outside the tile on the very last STG do not write.
            const bool is_active = !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_);
            if( !is_active ) {
                continue;
            }

            char *dst = this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes;
            if (BYTES_PER_ELEMENT == 4) {
                fmha::stg(dst, src[ii]);
            } else {
                const float x = reinterpret_cast<const float &>(src[ii].x);
                const float y = reinterpret_cast<const float &>(src[ii].y);
                const float z = reinterpret_cast<const float &>(src[ii].z);
                const float w = reinterpret_cast<const float &>(src[ii].w);
                const uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
                fmha::stg(dst, out);
            }
        }
    }

    // Load chunk mi of the tile from global memory (4-byte elements only). Entries that are out
    // of range or inactive are left untouched in dst.
    inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
        static_assert(BYTES_PER_ELEMENT == 4);
        const int row_ = tidx_ / THREADS_PER_ROW;
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            const int jj = mi * STGS_PER_LOOP + ii;
            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                break;
            }

            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
            }
        }
    }

    // Advance the tile by steps * ROWS rows and shrink the remaining sequence length accordingly.
    // steps is expected to be non-negative: the byte offset is formed in unsigned 32-bit
    // arithmetic, so a negative value would wrap instead of moving the pointer backwards.
    inline __device__ void move(const int steps = 1) {
        ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
        actual_seqlen_q -= ROWS * steps;
    }

    // The stride between rows of the output matrix.
    const uint32_t row_stride_in_bytes;
    // The pointer.
    char *ptr_;
    // Is the thread active for the last STG?
    int is_active_for_last_stg_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen_q;
    // The thread index; the row processed by the thread is recomputed from it on demand.
    const int tidx_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Global-memory tile storing one vector of 8 elements per thread and per MMA.
//
// The data of one (batch, head) block occupies seqlen_q * seqlen_k elements. Inside a block every
// (mi, ni) MMA owns a contiguous "row" of THREADS_PER_CTA * BYTES_PER_STG bytes, and thread tidx
// owns the tidx-th vector of that row.
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // Each STG stores 8 elements.
    static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
    // The number of MMAs in the M dimension.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    // The number of MMAs in the N dimension.
    static constexpr int MMAS_N = Mma_tile::MMAS_N;
    // The number of rows computed per MMA per thread block.
    static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
    // The number of cols computed per MMA per thread block.
    static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
    // The number of threads per block.
    static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
    // The size of each row in bytes. I.e. how many bytes are stored per STG.
    static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
    // The distance between elements stored per loop (in bytes).
    static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;

    // The type of elements stored per STG.
    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;

    // Ctor.
    //
    // ptr     base pointer of the buffer in global memory.
    // params  kernel parameters; h, seqlen_q and seqlen_k are read from it.
    // bidb    index of the batch entry.
    // bidh    index of the head.
    // tidx    index of the thread inside the CTA.
    template<typename Params>
    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx) 
        : ptr_(static_cast<char *>(ptr)) {

        // The block index. 64-bit, because the product with the block stride below exceeds 32 bits
        // as soon as the whole buffer is larger than 4 GiB (e.g. 512 blocks of 2048 x 2048 fp16).
        const int64_t bidx = static_cast<int64_t>(bidb) * params.h + bidh;

        // The distance between two blocks (in bytes).
        const int64_t block_stride_bytes =
            static_cast<int64_t>(params.seqlen_q) * params.seqlen_k * BYTES_PER_ELEMENT;

        // Set store location for each thread at the beginning of the loop
        ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
    }

    // Store to global memory the vector of the (mi, ni) MMA.
    inline __device__ void store(const Type &data, const int mi, const int ni) {
        // Bounded by LOOP_STRIDE_BYTES, so 32 bits are enough here.
        const uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::stg(ptr_ + offset, data);
    }

    // Load from global memory the vector of the (mi, ni) MMA.
    inline __device__ void load(Type &data, const int mi, const int ni) {
        // Bounded by LOOP_STRIDE_BYTES, so 32 bits are enough here.
        const uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::ldg(data, ptr_ + offset);
    }

    // Move to the next tile.
    inline __device__ void move(const int steps = 1) {
        ptr_ += LOOP_STRIDE_BYTES * steps;
    }

    // The pointer in global memory.
    char *ptr_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Global-memory tile for the S matrix, built on Gmem_tile_mma_sd with 2-byte elements by default.
// It moves whole arrays of MMA fragments and skips the MMAs for which the mask reports no valid
// element.
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {

    // The number of mmas in the vertical dimension.
    static constexpr int M = Base::MMAS_M;
    // The number of mmas in the horizontal dimension.
    static constexpr int N = Base::MMAS_N;
    // The type of the vectors stored by each STG.
    using Type = typename Base::Type;

    // Ctor. Forwards params.s_ptr and the batch / head indices of binfo to the base tile.
    template< typename Params, typename Block_info >
    inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx) 
        : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
    }

    // Store to global memory. The fragment registers are written in the order 0, 2, 1, 3.
    template<typename Mask, typename Fragment>
    inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
        #pragma unroll
        for( int mi = 0; mi < M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < N; ++ni ) {
                const Fragment &src = frag[ni][mi];
                uint4 packed;
                packed.x = src.reg(0);
                packed.y = src.reg(2);
                packed.z = src.reg(1);
                packed.w = src.reg(3);
                // Nothing to write when the mask rejects the whole MMA.
                if( !mask.any_valid(mi, ni) ) {
                    continue;
                }
                Base::store(packed, mi, ni);
            }
        }
    }

    // Load from global memory. Entries of MMAs rejected by the mask are set to zero.
    template<typename Mask>
    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < N; ++ni ) {
                uint4 &dst = regs[mi][ni];
                dst = make_uint4(0, 0, 0, 0);
                if( mask.any_valid(mi, ni) ) {
                    Base::load(dst, mi, ni);
                }
            }
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Global-memory accessor for per-row summary statistics stored as 4-byte values.
//
// The buffer holds seqlen_q values per (batch, head) block. Two pointers are kept:
//   ptr_      points at this thread's MMA row: block base + (lane / 4) elements;
//   ptr_row_  points at the block base and is indexed with explicit row numbers.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile
>
struct Gmem_summary_stats {

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;

    // The size of each element.
    static constexpr int BYTES_PER_ELEMENT = 4;
    // The number of bytes covered by one MMA: (THREADS_PER_WARP / 4) * 2 elements.
    static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
    static constexpr int ROWS = Cta_tile::M;

    // Ctor.
    //
    // ptr     base pointer of the statistics buffer in global memory.
    // params  kernel parameters; h and seqlen_q are read from it.
    // tidx    index of the thread inside the CTA.
    // The batch and head indices are taken from blockIdx.x and blockIdx.y respectively.
    template<typename Params>
    inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
        : ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {

        // The block index for the batch.
        const int bidb = blockIdx.x;
        // The block index for the head.
        const int bidh = blockIdx.y;
        // The block index. 64-bit so that the product with the block stride cannot wrap.
        const int64_t bidx = static_cast<int64_t>(bidb) * params.h + bidh;

        // Extract the position in the warp.
        const int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // The distance between two blocks (in bytes).
        const int64_t block_stride_bytes = static_cast<int64_t>(params.seqlen_q) * BYTES_PER_ELEMENT;
        const int64_t block_offset_bytes = bidx * block_stride_bytes;

        // Set store location for each thread at the beginning of the loop
        ptr_row_ = ptr_ + block_offset_bytes;
        ptr_ += block_offset_bytes + (lane / 4) * BYTES_PER_ELEMENT;
    }

    // Store data to global memory. Only the threads of warp 0 whose lane is a multiple of 4
    // write; each writes two values per MMA, 8 elements apart.
    inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
        const int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
        const int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
        if ((warp == 0) && (lane % 4 == 0)) {
            #pragma unroll
            for (int mi = 0; mi < MMAS_M; ++mi) {
                // TODO: Not sure if it's right for MMAS_M > 1
                fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
                fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
            }
        }
    }

    // Store one value per MMA to global memory, at the given row relative to the block base.
    inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
        }
    }

    // Load from global memory the two values per MMA owned by this thread.
    inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
            fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
        }
    }

    // Load from global memory the values located move_steps tiles ahead, without moving the
    // pointers.
    inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
        char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            // TODO: Not sure if it's right for MMAS_M > 1
            fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
            fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
        }
    }

    // Load from global memory the values of N rows, given as row numbers relative to the block
    // base.
    template <int N>
    inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
        #pragma unroll
        for (int ni = 0; ni < N; ++ni) {
            fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
        }
    }

    // Move the pointers to the next tile of ROWS rows.
    inline __device__ void move() {
        ptr_ += ROWS * BYTES_PER_ELEMENT;
        ptr_row_ += ROWS * BYTES_PER_ELEMENT;
    }

    // Move the pointers by steps tiles of ROWS rows.
    inline __device__ void move(const int steps) {
        ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
        ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
    }

    // The pointer to this thread's row.
    char *ptr_;
    // The pointer to the first row of the block.
    char *ptr_row_;
    // The thread index.
    const int tidx_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha