smem_tile.h 72.3 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include "utils.h"
#include <fmha/utils.h>
#include <fmha/gemm.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// A 2D tile of M_ rows x N_ cols stored in shared memory. Rather than padding rows ("skews"),
// shared-memory bank conflicts are avoided by XOR-swizzling the BYTES_PER_STS-wide column index
// of each access with the row index (see the ctor and compute_store_pointers()).
//
// Layout summary:
//   * N_ is rounded up to a power of two (N_WITH_PADDING).
//   * A shared-memory row holds at least 128 bytes; when a matrix row is narrower, several
//     matrix rows are packed into one shared-memory row (so ROWS <= M_).
//   * The tile holds BUFFERS_PER_TILE_ buffers of BYTES_PER_BUFFER bytes each. The index of the
//     current buffer is folded directly into smem_read_offset_ / smem_write_offset_.
//
// The ctor computes the per-thread write offset only; smem_read_offset_ starts at 0 and is meant
// to be set by derived tiles that know their read pattern.
template<
    // The description of the tile computed by this CTA.
    typename Cta_tile,
    // The number of rows in the 2D shared memory buffer.
    int M_,
    // The number of cols.
    int N_,
    // The size in bits of each element.
    int BITS_PER_ELEMENT_,
    // The number of bytes per STS.
    int BYTES_PER_STS_ = 16,
    // The number of buffers. (Used in multistage and double buffer cases.)
    int BUFFERS_PER_TILE_ = 1,
    // Do we enable the fast path for LDS.128 and friends. Only 0 is supported (static_assert).
    int ENABLE_LDS_FAST_PATH_ = 0,
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    int ROWS_PER_XOR_PATTERN_ = 8,
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    int COLS_PER_XOR_PATTERN_ = 1,
    // Use or not predicates
    bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {

    // The size in bits of each element.
    enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
    // The size in bytes of a single STS.
    enum { BYTES_PER_STS = BYTES_PER_STS_ };
    // The number of elements per STS.
    enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
    // To support arbitrary N, we pad some values to a power-of-2.
    enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
    // The number of bytes per row without packing of rows.
    enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
    // The number of bytes per row -- we want at least 128B per row.
    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
    // The number of rows in shared memory (two rows may be packed into a single one).
    enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };

    // The number of threads per row.
    enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
    // The number of threads per row.
    enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };

    // The number of STS per row.
    enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
    // It must be at least one.
    static_assert(STS_PER_ROW >= 1, "");
    // The number of rows written with a single STS.
    enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
    static_assert(ROWS_PER_STS >= 1, "");
    // The number of STS needed to store all rows.
    enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
    // The number of STS in total.
    enum { STS = STS_PER_COL * STS_PER_ROW };

    // TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
    // we only need to store 16 * 64 * 2 = 2KB instead of 4KB.
    // PARTIAL_STORE: one STS pass of the whole CTA covers more rows than the tile has, so only
    // the STORING_THREADS threads whose row is < ROWS actually write.
    static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS;
    static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA;

    // The size of one buffer in bytes in shared memory. Sized by the threads that actually store
    // (previously: STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA).
    enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS };
    // The number of buffers.
    enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
    // The size in bytes of total buffers.
    enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
    // The boundary for smem_read_offset and smem_write_offset increment.
    enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };

    // Do we enable the LDS.128 fast path?
    enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
    static_assert(ENABLE_LDS_FAST_PATH == 0);
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
    // Use or not predicates
    enum { USE_PREDICATES = USE_PREDICATES_ };

    // The type of elements that are stored in shared memory by each thread.
    using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;

    // Ctor.
    //   smem: base of the tile; converted to a 32-bit shared-memory address.
    //   tidx: index of the calling thread in the CTA.
    // Computes smem_write_offset_; smem_read_offset_ is zeroed and left to derived tiles.
    inline __device__ Smem_tile_without_skews(void *smem, int tidx)
        : smem_(__nvvm_get_smem_pointer(smem)), smem_read_offset_(0), tidx_(tidx) {

        // The row written by a thread. See doc/mma_smem_layout.xlsx.
        int smem_write_row = tidx / THREADS_PER_ROW;

        // The XOR pattern.
        int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
        // Compute the column and apply the XOR pattern.
        int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;

        // The offset.
        this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
    }

    // Compute the N store addresses of this thread (STS ii covers row-group ii / STS_PER_ROW and
    // column-group ii % STS_PER_ROW). The current write buffer is already part of
    // smem_write_offset_.
    template< int N >
    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
        #pragma unroll
        for( int ii = 0; ii < N; ++ii ) {
            // Decompose the STS into row/col.
            int row = ii / STS_PER_ROW;
            int col = ii % STS_PER_ROW;

            // Assemble the offset.
            int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;

            // Take the column into account.
            if( STS_PER_ROW > 1 ) {
                offset += col*THREADS_PER_ROW*BYTES_PER_STS;
            }

            // Apply the XOR pattern if needed: when one STS pass covers fewer rows than the
            // swizzle period, later passes land on a different phase of the pattern.
            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
            }

            // Assemble the final pointer :)
            ptrs[ii] = smem_ + offset;
        }
    }

    // Debug helper: thread 0 zeroes every buffer of the tile, 4 bytes at a time.
    inline __device__ void debug_reset() {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val = 0x0;
                        sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
                    }
                }
            }
        }
    }

    // Print the content of the tile (only for debug ;)). Only thread 0 prints.
    inline __device__ void debug_print() const {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val;
                        lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
                        printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
                            blockIdx.x,
                            blockIdx.y,
                            blockIdx.z,
                            smem_,
                            buffer,
                            row,
                            col,
                            val);
                    }
                }
            }
        }
    }

    // Move the read offset to the next buffer, wrapping from the last buffer to the first.
    // No-op when BUFFERS_PER_TILE == 1. Assumes the in-buffer part of the offset is smaller
    // than BYTES_PER_BUFFER.
    inline __device__ void move_to_next_read_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_offset_ += BYTES_PER_BUFFER;
        }
    }

    // Move the read offset to next buffer. TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer() {
        this->move_to_next_read_buffer();
    }

    // Move the read offset N buffers ahead (circular buffer). Wraps at most once, so N is
    // expected to be in [0, BUFFERS_PER_TILE].
    inline __device__ void move_to_next_read_buffer(int N) {
        if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_offset_ += N * BYTES_PER_BUFFER;
            this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
        }
    }

    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer(int N) {
        this->move_to_next_read_buffer(N);
    }

    // Move the write offset to the next buffer, wrapping from the last buffer to the first.
    // No-op when BUFFERS_PER_TILE == 1.
    inline __device__ void move_to_next_write_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_write_offset_ += BYTES_PER_BUFFER;
        }
    }

    // Move the write offset to next buffer. TODO: Remove that member function!
    inline __device__ void move_next_write_buffer() {
        this->move_to_next_write_buffer();
    }

    // Move the read offset by `delta` bytes.
    inline __device__ void move_read_offset(int delta) {
        this->smem_read_offset_ += delta;
    }

    // Move the write offset by `delta` bytes.
    inline __device__ void move_write_offset(int delta) {
        this->smem_write_offset_ += delta;
    }

    // True when this thread owns a row of the tile. Only false with PARTIAL_STORE, for threads
    // whose row is >= ROWS; their addresses would fall outside the (shrunk) buffer.
    inline __device__ bool is_storing_thread() const {
        return !PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS);
    }

    // Store to the tile in shared memory.
    template< int N >
    inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        // Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer.
        if( this->is_storing_thread() ) {
            sts(smem_ptrs, data);
        }
    }

    // Store to the tile in shared memory, predicated by `preds`.
    template< int N, int M >
    inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
        uint32_t smem_ptrs[N];
        this->compute_store_pointers(smem_ptrs);
        // Same guard as the unpredicated store: with PARTIAL_STORE, surplus threads would
        // otherwise write past the end of the buffer.
        if( this->is_storing_thread() ) {
            sts(smem_ptrs, data, preds);
        }
    }

    // Store to the tile in shared memory, predicated by a single predicate word.
    template< int N >
    inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
        // Forwarding the scalar `preds` directly would select this very overload again (exact
        // match) and recurse forever; wrap it so the array-predicate overload is chosen.
        uint32_t tmp[1] = { preds };
        this->store(data, tmp);
    }

    // Store to the tile in shared memory.
    // NOTE(review): no store() overload of this class accepts (const void* (&)[N],
    // uint32_t (&)[1]), so this overload does not compile if instantiated; confirm its intended
    // target.
    template< int N >
    inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
        uint32_t tmp[1] = { preds };
        this->store(gmem_ptrs, tmp);
    }

    // The shared memory pointer (32-bit shared-space address).
    const uint32_t smem_;
    // The read offset in bytes, including the current read buffer.
    int smem_read_offset_;
    // The write offset in bytes, including the current write buffer.
    int smem_write_offset_;
    // The index of this thread in the CTA (used by the PARTIAL_STORE guard).
    const int tidx_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Primary template of the A-operand shared-memory tile. It is deliberately empty; concrete
// tiles come from partial specializations on the Layout parameter.
//   Cta_tile         - dimensions of the tile computed by the CTA.
//   Layout           - layout of the tile.
//   BYTES_PER_STS    - size of one STS in bytes.
//   BUFFERS_PER_TILE - number of buffers per tile.
//   USE_PREDICATES   - whether to use predicates.
template< typename Cta_tile,
          typename Layout,
          int BYTES_PER_STS = 16,
          int BUFFERS_PER_TILE = 1,
          bool USE_PREDICATES = true >
struct Smem_tile_a {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Compile-time mask used to reset the XOR-walked read offset of the A/B tiles.
// With MMAS_K_WITH_PADDING a power of two, VALUE equals MMAS_K when
// MMAS_K < MMAS_K_WITH_PADDING, and MMAS_K - 1 when the two are equal.
// The recursion peels one power-of-two bit (HALF) per level.
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
    // The candidate bit contributed at this level.
    enum { HALF = MMAS_K_WITH_PADDING / 2 };
    // What is left of MMAS_K below the HALF bit.
    enum { MOD = MMAS_K % HALF };
    // HALF if MMAS_K has the HALF bit set, 0 otherwise.
    enum { HALF_BIT = (MMAS_K != MOD) ? HALF : 0 };
    // Combine with the mask of the remainder (recursing on half the extent).
    enum { VALUE = HALF_BIT | Compute_reset_mask<MOD, HALF>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Recursion end: nothing left to reset.
template< int PADDED >
struct Compute_reset_mask<0, PADDED> {
    enum { VALUE = 0 };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// No padding: every bit below K is part of the mask.
template< int K >
struct Compute_reset_mask<K, K> {
    enum { VALUE = K - 1 };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of rows in the XOR swizzle of an A tile whose rows hold N elements.
template< int N >
struct Rows_per_xor_pattern_a {
    // Width in bits of N elements of the A operand.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
    // 8 rows above 512 bits, 4 rows in (256, 512] bits, 2 rows otherwise.
    enum { VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : 2) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Row-major A tiles use the generic A-operand XOR row count unchanged.
template< int N >
struct Rows_per_xor_pattern_row_a : Rows_per_xor_pattern_a<N> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Shared-memory tile for the A operand in row-major layout: Cta_tile::M rows of Cta_tile::K
// elements of fmha::BITS_PER_ELEMENT_A bits, stored through Smem_tile_without_skews (LDS fast
// path disabled, ROWS_PER_XOR_PATTERN_ swizzle rows, 1 swizzle column). Fragments are read with
// ldsm (LDSM.M88.4), four 32-bit registers per thread per MMA.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
                                                               Cta_tile::M,
                                                               Cta_tile::K,
                                                               fmha::BITS_PER_ELEMENT_A,
                                                               BYTES_PER_STS,
                                                               BUFFERS_PER_TILE,
                                                               0,
                                                               ROWS_PER_XOR_PATTERN_,
                                                               1> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::M,
                                         Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_A,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         1>;
    // The fragment.
    using Fragment = Fragment_a<Row>;

    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
    // The number of MMAs.
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // Ctor. Base sets the write offset; this sets smem_read_offset_ for thread `tidx`.
    inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {

        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;

        // Only a single warp along M (and K) is supported, so all warps read the same A rows.
        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);

        // The row and column read by the thread.
        // The low 4 bits of tidx pick one of 16 matrix rows (independent of the warp index).
        int smem_read_row  = (tidx & 0x0f);
        // Number of matrix rows packed into one shared-memory row.
        constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
        // Column (in 16-byte units) after applying the XOR swizzle of the physical smem row.
        int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
        // Bit 4 of tidx (lanes 16-31) selects the other 16-byte half of the 32-byte K slice.
        smem_read_col ^= (tidx & 0x10) / 16;

        // The shared memory offset. Matrix rows are BYTES_PER_ROW_BEFORE_PACKING bytes apart,
        // whether or not they are packed together into wider smem rows.
        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    // Applies the 1 * (2 * BYTES_PER_LDS) XOR step; `ki` is currently unused.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // Load from shared memory: one ldsm per MMA along M for K-step `ki`, then advance the read
    // offset to the next K-step.
    inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
        #pragma unroll
        for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);

            // Store the value into the fragment.
            a[mi].reg(0) = tmp.x;
            a[mi].reg(1) = tmp.y;
            a[mi].reg(2) = tmp.z;
            a[mi].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        // The XOR applied here is (ki ^ (ki + 1)) & (MMAS_K_WITH_PADDING - 1), in units of
        // 2 * BYTES_PER_LDS bytes. Net effect: when load() is called with ki = 0, 1, 2, ... in
        // order, the call with ki reads at the initial offset XOR
        // (ki % MMAS_K_WITH_PADDING) * 2 * BYTES_PER_LDS.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {
            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {
            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;
        }
    }

    // Reset the read offset by XOR-ing Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
    // steps of 2 * BYTES_PER_LDS bytes.
    // NOTE(review): presumably undoes the XOR walk of the preceding load() calls; confirm the
    // exact call sequence (number of load()/reverse_smem_read_offset() calls) against callers.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Row-major A tile: a thin alias of Smem_tile_row_a with the default XOR pattern.
template< typename Cta_tile, int BYTES_PER_STS, int BUFFERS_PER_TILE >
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The implementation this specialization forwards to.
    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor: all offsets are set up by Base.
    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Primary template of the B-operand shared-memory tile. It is deliberately empty; concrete
// tiles come from partial specializations on the Layout parameter.
//   Cta_tile         - dimensions of the tile computed by the CTA.
//   Layout           - layout of the tile.
//   BYTES_PER_STS    - size of one STS in bytes.
//   BUFFERS_PER_TILE - number of buffers per tile.
//   USE_PREDICATES   - whether to use predicates.
template< typename Cta_tile,
          typename Layout,
          int BYTES_PER_STS = 16,
          int BUFFERS_PER_TILE = 1,
          bool USE_PREDICATES = true >
struct Smem_tile_b {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of rows in the XOR swizzle of a B tile whose rows hold N elements.
template< int N >
struct Rows_per_xor_pattern_b {
    // Width in bits of N elements of the B operand.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
    // 8 rows above 512 bits, 4 rows in (256, 512] bits, 2 rows otherwise.
    enum { VALUE = N_IN_BITS > 512 ? 8 : (N_IN_BITS > 256 ? 4 : 2) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Column-major B tiles use the generic B-operand XOR row count unchanged.
template< int N >
struct Rows_per_xor_pattern_col_b : Rows_per_xor_pattern_b<N> {};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Shared-memory tile for the B operand in column-major layout: Cta_tile::N rows of Cta_tile::K
// elements of fmha::BITS_PER_ELEMENT_B bits, stored through Smem_tile_without_skews (LDS fast
// path disabled, ROWS_PER_XOR_PATTERN_ swizzle rows, 1 swizzle column). Fragments are read with
// ldsm (LDSM.M88.4), four 32-bit registers per thread per MMA.
template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
                                                           Cta_tile::N,
                                                           Cta_tile::K,
                                                           fmha::BITS_PER_ELEMENT_B,
                                                           BYTES_PER_STS,
                                                           BUFFERS_PER_TILE,
                                                           0,
                                                           ROWS_PER_XOR_PATTERN_,
                                                           1> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::N,
                                         Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         1>;
    // The fragment.
    using Fragment = Fragment_b< Col>;

    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
    // The number of MMAs.
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // The number of STS per thread (ROWS * THREADS_PER_ROW stores spread over the CTA).
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor. Base sets the write offset; this sets smem_read_offset_ for thread `tidx`.
    inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {

        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;
        static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);

        // The masks to select the warps.
        // (Presumably the tidx bits that identify the warp along N; Warp_masks is defined
        // elsewhere.)
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;

        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;

        // The row and column read by the thread.
        // Row = (warp index along N) * N_PER_MMA + (tidx & 7) + 8 if bit 4 of tidx is set.
        int smem_read_row  = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
                             (tidx & 0x07) +
                             (tidx & 0x10) / 2;
        // Number of matrix rows packed into one shared-memory row.
        constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
        // Column (in 16-byte units) after applying the XOR swizzle of the physical smem row.
        int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
        // Bit 3 of tidx selects the other 16-byte half of the 32-byte K slice.
        smem_read_col ^= (tidx & 0x08) / 8;
        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    // Applies the 1 * (2 * BYTES_PER_LDS) XOR step; `ki` is currently unused.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // Load from shared memory: one ldsm per MMA along N for K-step `ki`, then advance the read
    // offset to the next K-step.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            // ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);

            // Store the value into the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        // The XOR applied here is (ki ^ (ki + 1)) & (MMAS_K_WITH_PADDING - 1), in units of
        // 2 * BYTES_PER_LDS bytes. Net effect: when load() is called with ki = 0, 1, 2, ... in
        // order, the call with ki reads at the initial offset XOR
        // (ki % MMAS_K_WITH_PADDING) * 2 * BYTES_PER_LDS.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {
            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {
            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;
        }
    }

    // Reset the read offset by XOR-ing Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE
    // steps of 2 * BYTES_PER_LDS bytes.
    // NOTE(review): presumably undoes the XOR walk of the preceding load() calls; confirm the
    // exact call sequence against callers.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Smem_tile_b specialization for the `Col` layout.
//   Cta_tile         - the dimensions of the tile computed by the CTA.
//   BYTES_PER_STS    - the size of each STS, in bytes.
//   BUFFERS_PER_TILE - the number of buffers per tile.
// All behaviour comes from Smem_tile_col_b; this specialization only routes
// Smem_tile_b<..., Col, ...> to it.
template< typename Cta_tile, int BYTES_PER_STS, int BUFFERS_PER_TILE >
struct Smem_tile_b<Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The implementation this specialization forwards to.
    using Base = Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Hand the shared-memory pointer and thread index straight to the base constructor.
    inline __device__ Smem_tile_b(void *smem, int tidx)
        : Base(smem, tidx) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Number of rows in the XOR swizzle pattern for row-major B tiles of width N.
// It reuses the generic Rows_per_xor_pattern_b selection unchanged (VALUE is inherited).
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b<N> {};

////////////////////////////////////////////////////////////////////////////////////////////////////


template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
    // How many cols to use for the XOR pattern to avoid bank conflicts?
    int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::K,
                                                        Cta_tile::N,
                                                        fmha::BITS_PER_ELEMENT_B,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        COLS_PER_XOR_PATTERN_> {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::K,
                                         Cta_tile::N,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         COLS_PER_XOR_PATTERN_>;
    // The fragment.
    using Fragment = Fragment_b<Row>;

    // Can we use LDSM? No if the data type is 32-bit large.
    enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
    // The number of elements per LDS.
    enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };

    // The number of STS per thread
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor. Computes the per-thread smem_read_offset_ consumed by load().
    // Supported configurations (enforced by static_assert):
    //   WARPS_K == 1, WARPS_N == 1, WARPS_M in {4, 8};
    //   16-bit elements (LDSM.T path);
    //   a 2-, 4- or 8-row XOR pattern.
    inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;
        static_assert(WARPS_K == 1);
        static_assert(WARPS_M == 4 || WARPS_M == 8);
        static_assert(WARPS_N == 1);

        // The masks to select the warps.
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
        const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;

        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M *       1 * Cta_tile::THREADS_PER_WARP;
        const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;

        static_assert(USE_LDSMT);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);

        // The row/col read by the thread.
        int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
                            (tidx & 0x07) + (tidx & 0x08);
        // How many logical matrix rows are packed into one physical smem row.
        constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
        int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
        smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
    }

    // XOR smem_read_offset_ from the column used by MMA `ni` to the column used by MMA `ni + 1`.
    // Shared by load() and reverse_smem_read_offset() so the two always apply the same steps.
    // For every supported MMAS_N, applying this for ni = 0..MMAS_N-1 and then calling
    // finish_read_offset_ni_loop() returns the offset to its starting value.
    inline __device__ void move_read_offset_to_next_ni(int ni) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        if( BYTES_PER_MMA_PER_CTA >= 128 ) {
            // Nothing to do! load() adds ni * BYTES_PER_MMA_PER_CTA explicitly.
        } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
            // Nothing to do!
        } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
        } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
        } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // After the ni loop: an odd number of 64B XOR steps (odd MMAS_N > 1, npo2 kernels)
    // leaves one XOR outstanding, so cancel it here.
    inline __device__ void finish_read_offset_ni_loop() {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
                Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    // Replays exactly the per-ni pointer movement performed by load(). `ki` is unused and
    // kept only for interface compatibility.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            this->move_read_offset_to_next_ni(ni);
        }
        this->finish_read_offset_ni_loop();
    }

    // Load from shared memory the B fragments of all MMAS_N MMAs for K-step `ki`.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Row offset for this ki, plus the (possibly ni-dependent) column offset.
            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING;
            if( BYTES_PER_MMA_PER_CTA == 32 ) {
                offset += this->smem_read_offset_;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
            } else {
                offset += this->smem_read_offset_ + ni * BYTES_PER_MMA_PER_CTA;
            }

            // Load the data using LDSM.MT88.2.
            uint32_t ptr = this->smem_ + offset;
            uint4 tmp;
            if( USE_LDSMT ) {
                ldsmt(tmp, ptr);
            } else {
                // Fallback for 32-bit elements; unreachable while the ctor asserts USE_LDSMT.
                lds(tmp.x, (ptr     ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.y, (ptr     ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
                lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
            }

            // Store those values in the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Move the pointer for the next ni.
            this->move_read_offset_to_next_ni(ni);
        }

        this->finish_read_offset_ni_loop();
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Smem_tile_b specialization for the `Row` layout.
//   Cta_tile         - the dimensions of the tile computed by the CTA.
//   BYTES_PER_STS    - the size of each STS, in bytes.
//   BUFFERS_PER_TILE - the number of buffers per tile.
// All behaviour comes from Smem_tile_row_b (with its default XOR-pattern parameters);
// this specialization only routes Smem_tile_b<..., Row, ...> to it.
template< typename Cta_tile, int BYTES_PER_STS, int BUFFERS_PER_TILE >
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {

    // The implementation this specialization forwards to.
    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Hand the shared-memory pointer and thread index straight to the base constructor.
    inline __device__ Smem_tile_b(void *smem, int tidx)
        : Base(smem, tidx) {}
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1> {

    // Underlying K x N shared-memory tile: 16-bit elements, 16B STS, one buffer,
    // swizzled with the column-B XOR pattern.
    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1>;
    // MMA tiling of the CTA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // Register fragment filled for each MMA along N.
    using Fragment = Fragment_b< fmha::Col>;

    // Every read is a single 16-byte LDSM.T.
    enum { BYTES_PER_LDS = 16 };

    // Computes this thread's smem_read_offset_. Requires WARPS_M == WARPS_N == 1 and
    // WARPS_K in {4, 8}.
    inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

        // Each warp starts 16 rows after the previous one; lanes 0-15 select the row inside it.
        const int row = (tidx & 0xe0) / 2 + (tidx & 0x0f);

        // XOR swizzle of the column, computed on the packed (physical) row index.
        constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
        const int swizzle = ((row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
        // Lanes 16-31 read the neighbouring 16B column.
        const int col = swizzle ^ ((tidx & 0x10) / 16);

        this->smem_read_offset_ = row * Base::BYTES_PER_ROW_BEFORE_PACKING + col * BYTES_PER_LDS;
    }

    // Loads the B fragments of all MMAS_N MMAs for K-step `ki` using LDSM.T.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // Every ni at this ki reads from the same band of 16 * WARPS_K rows.
        const int band = ki * 16 * Cta_tile::WARPS_K * Base::BYTES_PER_ROW_BEFORE_PACKING;

#pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            uint4 regs;
            fmha::ldsmt(regs, this->smem_ + this->smem_read_offset_ + band);
            b[ni].reg(0) = regs.x;
            b[ni].reg(1) = regs.y;
            b[ni].reg(2) = regs.z;
            b[ni].reg(3) = regs.w;

            // XOR-step the read offset to the columns of the next MMA along N.
            switch( static_cast<int>(Mma_tile::MMAS_N) ) {
                case 1:
                    break;
                case 2:
                    this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
                    break;
                case 4:
                    this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
                    break;
                case 8:
                    this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
                    break;
                default:
                    assert(false);  // Not implemented!
            }
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Cta_tile>
struct Smem_tile_o {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    // The element type held by the accumulators.
    using Data_type = typename Accumulator::Data_type;

    // The size of each element.
    static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type);
    // The size of each STS.
    static constexpr int BYTES_PER_STS = 8;
    // The size of each row in shared memory: WARPS_K slices of N elements side by side.
    static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT;

    // The size of each LDS.
    static constexpr int BYTES_PER_LDS = 16;
    // The number of threads needed to read one N-element slice of a row with 16B LDS.
    static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS;

    // The number of rows.
    static constexpr int ROWS = Cta_tile::M;
    // The number of "rows" to process per loop iteration (in the "epilogue").
    static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
    // The number of outer loops.
    static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
    // Make sure it matches our expectations.
    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");

    // The number of rows loaded per LDS.
    static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // Do we have to guard against partial writes/reads.
    static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0;
    // The total number of LDS per loop.
    static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS);

    // The amount of shared memory.
    static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW;

    // The write and read pointers (32-bit shared-memory addresses).
    uint32_t smem_write_, smem_read_;
    // Is the thread active for the last LDS of the series? Always 1 when !HAS_INCOMPLETE_LDS.
    int is_active_for_last_lds_;

    // Ctor. Computes the per-thread write pointer used by store() and the read pointer
    // used by load(). Requires WARPS_M == WARPS_N == 1, WARPS_K in {4, 8} and N in {16, 32, 64, 128}.
    inline __device__ Smem_tile_o(void *smem, int tidx) {

        // Get a 32-bit value for the shared memory address.
        uint32_t smem_ = __nvvm_get_smem_pointer(smem);

        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
        static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128);

        // Write side: bits 2-4 of the thread index select one of 8 rows.
        int write_row = (tidx & 0x1c) / 4;

        const int lane = tidx % 32;
        const int warp = tidx / 32;

        // Each warp writes its own contiguous run of STS_PER_WARP 8-byte columns.
        constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT;
        constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
        int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;

        // Assemble the write pointer.
        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

        // The element read by each thread.
        int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;

        // Take the XOR pattern into account for the column. For N == 128 the pattern is
        // extended by an additional XOR applied in load().
        read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));

        // Assemble the read pointer.
        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;

        // Is that thread active on the last LDS?
        // NOTE(review): this bound is Cta_tile::M, while one loop only covers ROWS_PER_LOOP rows
        // (BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW). When LOOPS > 1 this may let threads
        // read past the per-loop tile. Confirm whether ROWS_PER_LOOP was intended.
        if( HAS_INCOMPLETE_LDS ) {
            this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
        } else {
            this->is_active_for_last_lds_ = 1;
        }
    }

    // Load the output fragments, summing the WARPS_K partial slices of each row (split-K).
    // If zero_init, out[] is overwritten; otherwise the sum is accumulated into out[].
    template <bool zero_init=true>
    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
        #pragma unroll
        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {

            // Load the elements before the reduction (split-K).
            uint4 tmp[Cta_tile::WARPS_K];
            #pragma unroll
            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
                uint32_t smem_read = this->smem_read_ + imm;
                // TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
                if( (Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1) ) {
                    smem_read ^= 8 * BYTES_PER_LDS;
                }
                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
                    fmha::lds(tmp[jj], smem_read);
                }
            }

            // Perform the reduction.
            out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
            #pragma unroll
            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
            }
        }
    }

    // Store the accumulators of outer-loop iteration `mi` into shared memory.
    template <int M, int N>
    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
        static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
        // The number of MMAs that are stored per loop iteration.
        static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;
        static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

            // Store 1st column (regs 0-3) of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);

                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Swizzle the write pointer using a XOR of 32B.
            this->smem_write_ ^= 32;

            // Store 2nd column (regs 4-7) of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Cancel the previous 32B XOR and swizzle the write pointer towards the next ni.
            // All constants are multiples of 32B.
            if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
                this->smem_write_ ^= 15 * 32;
            } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
                this->smem_write_ ^= 7 * 32;
            } else {
                // Covers MMAS_N >= 2 as well as MMAS_N == 1.
                // NOTE(review): for MMAS_N == 1 this leaves smem_write_ XOR-ed by 64B (32 ^ 96)
                // after every store() call. Confirm that callers rely on this.
                this->smem_write_ ^= 3 * 32;
            }
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Smem_tile_mma {

    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    using Fragment = fmha::Fragment_a<fmha::Col>;

    enum { COLS = Cta_tile::N };
    enum { BYTES_PER_ELT = 2 };
    enum { BYTES_PER_STS = 4 };
    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO
    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };

    enum { WARPS_M = Cta_tile::WARPS_M };
    enum { WARPS_N = Cta_tile::WARPS_N };
    enum { WARPS_K = Cta_tile::WARPS_K };

    static_assert(WARPS_K == 1);

    // Computes the per-thread 32-bit shared-memory write pointer used by store().
    // Two layouts are supported:
    //   WARPS_M == 1, WARPS_N in {4, 8}: warps are spread along the columns.
    //   WARPS_N == 1:                    warps are stacked along the rows.
    inline __device__ Smem_tile_mma(char *smem, int tidx) {
        uint32_t smem_ = __nvvm_get_smem_pointer(smem);

        int write_col, write_row;
        // Parenthesized explicitly. The previous `A && B || C || D` form also accepted e.g.
        // WARPS_M == 4 with WARPS_N == 4, which the row-stacked branch below does not handle.
        static_assert((WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) || WARPS_N == 1,
                      "Smem_tile_mma: unsupported warp layout");
        if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
            // Row from bits 2-4 of tidx; each warp owns 8 consecutive 4B columns.
            write_row = (tidx & 0x1c) / 4;
            write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
            write_col ^= (write_row & 0x07) * 4;
        } else {
            // Each warp owns a 16-row slab; lanes select row (bits 2-4) and column (bits 0-1).
            write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
            write_col = (tidx & 0x03);
            // XOR swizzle: the row mask grows with the row width and is capped at 0x07.
            write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07))) * 4;
        }

        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
    }

    // Store a grid of 16x16 register tiles regs[mi][ni]. Each uint4 is written as four
    // 4B STS: .x/.z at rows +0/+8, then .y/.w in the XOR-swizzled neighbouring column group.
    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        static_assert(COLS == Cta_tile::N);
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * BYTES_PER_STS;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    // Store fragments given as frag[ni][mi] (note the swapped index order) by repacking
    // them into uint4 tiles and forwarding to the overload above.
    template<typename Frag, int M, int N>
    inline __device__ void store(const Frag (&frag)[N][M]) {
        static_assert(COLS == Cta_tile::N);
        uint4 regs[M][N];
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                // Need to transpose reg(1) and reg(2) here since when we load it we transpose again.
                regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2),
                                          frag[ni][mi].reg(1), frag[ni][mi].reg(3));
            }
        }
        this->store(regs);
    }

    // 32-bit shared-memory address of this thread's first write.
    uint32_t smem_write_;
};

template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    // Only the column-spread layout (one warp row, 4 or 8 warp columns) is supported.
    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
    using Fragment = typename Base::Fragment;

    // Builds the base (write-side) tile and computes this thread's 32-bit read pointer.
    inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {

        uint32_t smem_ = __nvvm_get_smem_pointer(smem);
        // Lanes 0-15 select the row; the warp index and lane bits 2-4 select the 16B column.
        int read_row = (tidx & 0x0f);
        int read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

        // XOR swizzle with the same 0x07 row mask the write side uses for this warp layout.
        read_col ^= (read_row & 0x07);

        smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Load frag[mi][ni] with LDSM.T (transposed) reads.
    template<int M, int N>
    inline __device__ void load(Fragment (&frag)[M][N]) {
        static_assert(Base::COLS == Cta_tile::N);
        // Fully unrolled so frag[mi][ni] is indexed with compile-time constants.
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
                uint4 dst;
                fmha::ldsmt(dst, offset);
                frag[mi][ni].reg(0) = dst.x;
                frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!
                frag[mi][ni].reg(2) = dst.y;
                frag[mi][ni].reg(3) = dst.w;
            }
        }
    }

    // 32-bit shared-memory address of this thread's first read.
    uint32_t smem_read_;
};

// Epilogue tile on top of Smem_tile_mma.
//  - store() writes accumulator fragments (or pre-packed uint4 registers) through Base's write address.
//  - load() reads the tile back row-wise with 16-byte shared loads. Each group of THREADS_PER_ROW
//    consecutive threads covers one row, and ROWS_PER_LDS rows are covered per load pass.
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
    static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
    static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    // NOTE(review): `WARPS_N == 8` may have been intended as `WARPS_M == 8`. Left unchanged because
    // tightening it could reject configurations that currently compile.
    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);

    using Acc = fmha::Fragment_accumulator;

    inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
        uint32_t smem_ = __nvvm_get_smem_pointer(smem);
        const int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;
        // Row-dependent XOR swizzle of the 16-byte column index. 128-byte and 256-byte rows both use
        // the 3-bit mask 0x07. It must agree with the swizzle Base applies when writing (not visible here).
        static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256);
        read_col ^= (read_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
        smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Reads NUM_LDS 16-byte chunks per thread, ROWS_PER_LDS rows apart.
    inline __device__ void load(uint4 (&data)[NUM_LDS]) {
        #pragma unroll
        for( int ii = 0; ii < NUM_LDS; ii++ ) {
            uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
            fmha::lds(data[ii], offset);
        }
    }

    // Converts float accumulators to packed pairs of elem_type and stores them.
    // For each fragment:
    //   elts 0,1 / 4,5 -> first row;  elts 2,3 / 6,7 -> row + 8.
    //   The (4,5) and (6,7) pairs go to the column reached by XOR-ing 4 * BYTES_PER_STS bytes.
    template<typename elem_type=__half, int M, int N>
    inline __device__ void store(const Acc (&acc)[M][N]){
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                // 1st row - 4 elements per row.
                float tmp00 = acc[mi][ni].elt(0);
                float tmp01 = acc[mi][ni].elt(1);
                float tmp02 = acc[mi][ni].elt(4);
                float tmp03 = acc[mi][ni].elt(5);
                // 2nd row - 4 elements per row.
                float tmp10 = acc[mi][ni].elt(2);
                float tmp11 = acc[mi][ni].elt(3);
                float tmp12 = acc[mi][ni].elt(6);
                float tmp13 = acc[mi][ni].elt(7);

                uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
                uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
                uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
                uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);

                // ni selects a 32-byte column group by XOR on the (swizzled) base write address.
                uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(offset + 0 * BYTES_PER_ROW, x);
                fmha::sts(offset + 8 * BYTES_PER_ROW, z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(offset + 0 * BYTES_PER_ROW, y);
                fmha::sts(offset + 8 * BYTES_PER_ROW, w);
            }
        }
    }

    // Stores pre-packed registers using the same layout as the accumulator store above:
    // .x/.z go to rows +0/+8, and .y/.w go to rows +0/+8 of the XOR-ed column.
    // Uses Base's smem_write_ address, like the accumulator path.
    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    // Absolute shared-memory address of this thread's first 16-byte load.
    uint32_t smem_read_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Shared-memory round trip that writes column-major B fragments with 4-byte stores and reads them
// back with ldmatrix (fmha::ldsmt). Only WARPS_M == 1 with WARPS_N in {4, 8} is supported.
// Each warp reads back the columns it wrote itself. When store() and load() are called separately,
// the caller is presumably responsible for synchronizing between them.
template<typename Cta_tile>
struct Smem_tile_transpose {

    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    using Fragment_write = fmha::Fragment_b<fmha::Col>;
    using Fragment_read = fmha::Fragment_b<fmha::Col>;

    enum { COLS = Cta_tile::N };
    enum { BYTES_PER_ELT = 2 };
    enum { BYTES_PER_STS = 4 };
    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };  // TODO
    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };

    enum { BYTES_PER_LDS = 16 };

    enum { WARPS_M = Cta_tile::WARPS_M };
    enum { WARPS_N = Cta_tile::WARPS_N };
    enum { WARPS_K = Cta_tile::WARPS_K };

    static_assert(WARPS_K == 1);
    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));

    inline __device__ Smem_tile_transpose(char *smem, int tidx) {
        smem_ = __nvvm_get_smem_pointer(smem);

        // Write side. The class-level static_assert admits only WARPS_M == 1, so only the
        // "warps split along N" mapping is reachable:
        //   row = (tidx & 0x1c) / 4 in [0, 8);  col (4-byte units) = (tidx / 32) * 8 + (tidx & 3).
        int write_row = (tidx & 0x1c) / 4;
        int write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
        // XOR swizzle by the low 3 bits of the row: (row & 7) * 16 bytes, matching the read side.
        write_col ^= (write_row & 0x07) * 4;
        write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

        // Read side. Each lane supplies one ldmatrix row address: rows tidx & 0x0f, and two 16-byte
        // columns per warp, with the second taken by lanes 16..31. Same swizzle in 16-byte units.
        int read_row = (tidx & 0x0f);
        int read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
        read_col ^= (read_row & 0x07);
        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Stores one fragment at byte offset `offset` from smem_. reg(0)/reg(2) go to rows +0/+8; then the
    // column is XOR-ed by 4 * BYTES_PER_STS bytes and reg(1)/reg(3) go to rows +0/+8.
    inline __device__ void store_fragment_(const Fragment_write &frag, uint32_t offset) {
        fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag.reg(0));
        fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag.reg(2));
        offset ^= 4 * BYTES_PER_STS;
        fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag.reg(1));
        fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag.reg(3));
    }

    // Loads one fragment with transposing ldmatrix at byte offset `offset` from smem_.
    // Registers are taken in ldmatrix order x, y, z, w (no swap, unlike the A-fragment loader).
    inline __device__ void load_fragment_(Fragment_read &frag, uint32_t offset) {
        uint4 dst;
        fmha::ldsmt(dst, smem_ + offset);
        frag.reg(0) = dst.x;
        frag.reg(1) = dst.y;
        frag.reg(2) = dst.z;
        frag.reg(3) = dst.w;
    }

    // Writes frag_w[ni][mi] for ni in [0, N).
    // NOTE(review): the first array dimension (extent M) is indexed by ni < N, so this assumes N <= M.
    template<int M, int N>
    inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            store_fragment_(frag_w[ni][mi], write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT);
        }
    }

    template<int N>
    inline __device__ void load(Fragment_read (&frag_r)[N]) {
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            load_fragment_(frag_r[ni], read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT);
        }
    }

    // store() followed by load() in one call, with a warp barrier in between.
    // NOTE(review): frag_r (extent M) is indexed by ni < N, so this also assumes N <= M.
    template<int M, int N>
    inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
        static_assert(COLS == Cta_tile::N);
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            store_fragment_(frag_w[ni][mi], write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT);
        }
        // ldmatrix reads rows written by other lanes of this warp. Under independent thread
        // scheduling (sm_70+), intra-warp exchange through shared memory needs __syncwarp() to
        // order the stores before the loads. ldmatrix.sync.aligned already requires a converged warp.
        __syncwarp();
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            load_fragment_(frag_r[ni], read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT);
        }
    }

    uint32_t smem_;
    uint32_t write_offset_;
    uint32_t read_offset_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-row float sums kept in shared memory, with optional multi-buffering.
// Each buffer holds ROWS floats, and BUFFERS_PER_TILE buffers are laid out back to back from smem.
template<
    typename Gmem_tile,
    // The number of buffers. (Used in multistage and double buffer cases.)
    int BUFFERS_PER_TILE_ = 1
>
struct Smem_tile_dp_sum {

    using Cta_tile = typename Gmem_tile::Cta_tile;
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The size of each element.
    static constexpr int BYTES_PER_ELEMENT = 4;
    static constexpr int ROWS = Gmem_tile::ROWS;
    static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW;
    static constexpr int MMAS_M = Mma_tile::MMAS_M;

    static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG;
    static constexpr int LDGS = Gmem_tile::LDGS;

    static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA;

    // The size of one buffer in bytes in shared memory.
    static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT;
    // The number of buffers.
    static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_;
    // The size in bytes of total buffers.
    static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE;
    // The boundary for smem_read_offset and smem_write_offset increment.
    static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS;

    // Initializers listed in declaration order (tidx_, smem_, ...), which is the order C++ runs them.
    inline __device__ Smem_tile_dp_sum(float *smem, const int tidx)
        : tidx_(tidx), smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem) {
    }

    // Move the read offset to next buffer (wraps from the last buffer back to the first).
    inline __device__ void move_to_next_read_buffer() {
        if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
            this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_buffer_ += ROWS;
        }
    }

    // Move the write offset to next buffer (wraps from the last buffer back to the first).
    inline __device__ void move_to_next_write_buffer() {
        if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
            this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_write_buffer_ += ROWS;
        }
    }

    // Shared body of the two LDGS-wide stores. Only the first thread of each group of
    // THREADS_PER_ROW consecutive threads writes. Entry i goes to row
    // (tidx_ / THREADS_PER_ROW) + i * ROWS_PER_LDG of dst, and rows >= ROWS are skipped.
    inline __device__ void store_rows_(float *dst, const float (&sum)[LDGS]) const {
        if (tidx_ % THREADS_PER_ROW == 0) {
            const int row = tidx_ / THREADS_PER_ROW;
            #pragma unroll
            for (int i = 0; i < LDGS; ++i) {
                if (row + i * ROWS_PER_LDG < ROWS) {
                    dst[row + i * ROWS_PER_LDG] = sum[i];
                }
            }
        }
    }

    // Stores into the current write buffer.
    inline __device__ void store(const float (&sum)[LDGS]) {
        store_rows_(smem_write_buffer_, sum);
    }

    // Stores a single value into buffer `buffer_idx`. Only the first thread of each row group with a
    // row index < ROWS writes.
    inline __device__ void store(const float sum, const int buffer_idx) {
        float *smem_write = smem_ + buffer_idx * ROWS;
        int row = tidx_ / THREADS_PER_ROW;
        if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) {
            smem_write[row] = sum;
        }
    }

    // Stores into buffer `buffer_idx`.
    inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) {
        store_rows_(smem_ + buffer_idx * ROWS, sum);
    }

    // Stores into buffer 0 (smem_ itself); identical to store_pair(sum, 0).
    inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
        store_pair(sum, 0);
    }

    // Stores two sums per MMA row block into buffer `buffer_idx`. Lanes 4r..4r+3 of a warp share
    // row r: sum[2*mi] goes to row r and sum[2*mi+1] to row r + 8 of MMA block mi.
    // NOTE(review): the warp index does not enter the address, so every warp writes the same rows.
    // This is only correct if all warps hold identical sums for those rows (e.g. WARPS_M == 1) — confirm.
    inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
        float *smem_write = smem_ + buffer_idx * ROWS;
        const int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
        const int row = lane / 4;
        #pragma unroll
        for (int mi = 0; mi < MMAS_M; ++mi) {
            smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
            smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
        }
    }

    // Gathers sum[ni] = current_read_buffer[row[ni]].
    template<int N>
    inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            sum[ni] = smem_read_buffer_[row[ni]];
        }
    }

    // Gathers sum[ni] = buffer[buffer_idx][row[ni]].
    template<int N>
    inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) {
        float *smem_read = smem_ + buffer_idx * ROWS;
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            sum[ni] = smem_read[row[ni]];
        }
    }

    // All-reduce (sum) across THREADS_PER_ROW threads via fmha::Allreduce.
    static inline __device__ float reduce_warp(float sum) {
        fmha::SumOp<float> sum_op;
        return fmha::Allreduce<THREADS_PER_ROW>::run(sum, sum_op);
    }

    const int tidx_;
    float * const smem_;
    float *smem_read_buffer_;
    float *smem_write_buffer_;
};

}  // namespace fmha