// mainloop_fwd_sm90_tma_gmma_ws.hpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

24
template <typename Ktraits, bool Is_causal, typename Seqlen_traits>
Tri Dao's avatar
Tri Dao committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
struct CollectiveMainloopFwd {

    using Element = typename Ktraits::Element;
    using TileShape_MNK = typename Ktraits::TileShape_MNK;
    using ClusterShape = typename Ktraits::ClusterShape_MNK;

    static constexpr int kStages = Ktraits::kStages;
    static constexpr int kHeadDim = Ktraits::kHeadDim;

    using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
    using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
ganeshcolfax's avatar
ganeshcolfax committed
46
47
48
49
50
51
52
53

    using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVFp8 =
        decltype(tile_to_shape(SmemLayoutAtomVFp8{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutVFp16 = SmemLayoutK;
Tri Dao's avatar
Tri Dao committed
54
    // Note this is the transpose in terms of the view, not in terms of memory.
ganeshcolfax's avatar
ganeshcolfax committed
55
56
    using SmemLayoutVtFp16 =
        decltype(cute::composition(SmemLayoutVFp16{},
Tri Dao's avatar
Tri Dao committed
57
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
ganeshcolfax's avatar
ganeshcolfax committed
58
59
60
61
62
63
64
65
66
67
68
69
                                               make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));

    using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
    using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));

    // Dummy S layout for getting the shape for GEMM-II.
    using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutS =
        decltype(tile_to_shape(SmemLayoutAtomS{},
                 make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{}))));

Tri Dao's avatar
Tri Dao committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
    // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
    // using SmemLayoutVt =
    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
    //                            make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{}),
    //                            Step<_2, _1, _3>{}));  // This gives correct results, without Step it's wrong
    // using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::MN, Element,
    //     decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    // using SmemLayoutVt =
    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
    //              make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    // using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom<Element>;
    // using SmemLayoutVTMA =
    //     decltype(tile_to_shape(SmemLayoutAtomVTMA{},
    //                            make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using TMA_Q = decltype(make_tma_copy(
        GmemTiledCopyQ{},
87
88
89
90
91
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)), 
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), 
            typename Seqlen_traits::StrideT{}
        ),
Tri Dao's avatar
Tri Dao committed
92
93
94
95
96
97
        SmemLayoutQ{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for Q

    using TMA_KV = decltype(make_tma_copy(
        GmemTiledCopyKV{},
98
99
100
101
102
        make_tensor(
            make_gmem_ptr(static_cast<Element const*>(nullptr)), 
            repeat_like(typename Seqlen_traits::StrideT{}, int32_t(0)), 
            typename Seqlen_traits::StrideT{}
        ),
Tri Dao's avatar
Tri Dao committed
103
104
105
        take<0, 2>(SmemLayoutK{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
ganeshcolfax's avatar
ganeshcolfax committed
106
107
108
109
110
111
112
113
114
115
116
117
118
				   //
     using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{})));
     using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{})));
     using TileShapeV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TileShapeVFP8{}, TileShapeVFP16{}));
     using TMA_VFP8 = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
        take<0, 2>(SmemLayoutV{}),
        TileShapeV{},
        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
				   
    using TMA_V = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TMA_VFP8{}, TMA_KV{}));

Tri Dao's avatar
Tri Dao committed
119
120
121
122
123
124
125
126
127
128
129
130

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
    using PipelineParams = typename MainloopPipeline::Params;
    using PipelineState = typename MainloopPipeline::PipelineState;

    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
    static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);

    static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;

ganeshcolfax's avatar
ganeshcolfax committed
131

Tri Dao's avatar
Tri Dao committed
132
133
134
    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
135
        typename Seqlen_traits::LayoutT layout_Q;
Tri Dao's avatar
Tri Dao committed
136
        Element const* ptr_K;
137
        typename Seqlen_traits::LayoutT layout_K;
Tri Dao's avatar
Tri Dao committed
138
        Element const* ptr_V;
139
        typename Seqlen_traits::LayoutT layout_V;
Tri Dao's avatar
Tri Dao committed
140
141
142
143
144
        float const softmax_scale_log2;
    };

    // Device side kernel params
    struct Params {
145
146
147
        typename Seqlen_traits::LayoutT layout_Q;
        typename Seqlen_traits::LayoutT layout_K;
        typename Seqlen_traits::LayoutT layout_V;
Tri Dao's avatar
Tri Dao committed
148
        cutlass::FastDivmod qhead_per_khead_divmod;
Tri Dao's avatar
Tri Dao committed
149
        TMA_Q tma_load_Q;
ganeshcolfax's avatar
ganeshcolfax committed
150
151
        TMA_KV tma_load_K;
	TMA_V tma_load_V;
Tri Dao's avatar
Tri Dao committed
152
153
154
155
156
157
        float const softmax_scale_log2;
    };


    /// Build the device-side Params from the host Arguments.
    /// Wraps each gmem pointer in a tensor and builds the TMA descriptors for Q, K and V.
    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
        TMA_Q tma_load_Q = make_tma_copy(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            select<0, 2>(TileShape_MNK{}),
            _1{}); // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
        TMA_KV tma_load_K = make_tma_copy(
            GmemTiledCopyKV{},
            mK,
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        // V has K's logical shape. For FP8 the first two modes are swapped, so V is tiled as
        // (headdim, seqlen_k), and layout_V's stride is expected to describe that view.
        // Fix: Arguments has no `shape_K` member; the shape comes from layout_K.
        // load() must build the identical view for get_tma_tensor.
        auto gmemLayoutVFp16 = args.layout_K.shape();
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride());
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            TileShapeV{},  // same tile shape TMA_V was deduced with
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
                tma_load_Q, tma_load_K, tma_load_V,
                args.softmax_scale_log2};
    }

    /// Prefetch the Q, K and V TMA descriptors (in that order).
    /// Best issued by a single thread; every thread that calls this issues all three prefetches.
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params) {
        auto const desc_q = mainloop_params.tma_load_Q.get_tma_descriptor();
        cute::prefetch_tma_descriptor(desc_q);
        auto const desc_k = mainloop_params.tma_load_K.get_tma_descriptor();
        cute::prefetch_tma_descriptor(desc_k);
        auto const desc_v = mainloop_params.tma_load_V.get_tma_descriptor();
        cute::prefetch_tma_descriptor(desc_v);
    }

    /// Return the number of K/V blocks (kBlockN columns each) that m_block must visit.
    ///
    /// Without causal masking this is ceil(seqlen_k / kBlockN). With causal masking it is
    /// further capped at the block containing the last column visible to the last row of
    /// m_block. The diagonal is offset by seqlen_k - seqlen_q.
    /// mainloop_params is currently unused.
    CUTLASS_DEVICE
    int get_n_block_max(
          Params const& mainloop_params, int m_block,
          const Seqlen_traits& seqlen_traits_q,
          const Seqlen_traits& seqlen_traits_k
        ) {
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal) {
            n_block_max = std::min(n_block_max,
                                   cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN));
        }
        return n_block_max;
    }

    template <typename Scheduler, typename SharedStorage>
Tri Dao's avatar
Tri Dao committed
215
    CUTLASS_DEVICE void
Tri Dao's avatar
Tri Dao committed
216
    load(Params const& mainloop_params,
Tri Dao's avatar
Tri Dao committed
217
218
219
220
221
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
Tri Dao's avatar
Tri Dao committed
222
223
224
225
         Scheduler& scheduler,
         typename Scheduler::Params const& scheduler_params,
         typename Scheduler::WorkTileInfo& work_tile_info,
         cute::tuple<int32_t, int32_t, int32_t> block_coord,
226
227
228
         int work_idx,
         const Seqlen_traits& seqlen_traits_q,
         const Seqlen_traits& seqlen_traits_k
Tri Dao's avatar
Tri Dao committed
229
230
231
232
233
234
         ) {

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

235
236
        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
ganeshcolfax's avatar
ganeshcolfax committed
237
238
239
240
	auto gmemLayoutVFp16 = mainloop_params.shape_K;
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV);
Tri Dao's avatar
Tri Dao committed
241

Tri Dao's avatar
Tri Dao committed
242
243
244
        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

Tri Dao's avatar
Tri Dao committed
245
246
247
248
        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
ganeshcolfax's avatar
ganeshcolfax committed
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276

        Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{}, cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(make_coord(_0{}, _), make_coord(_, _0{})));  // (N, K, _)

#if 0
	if (threadIdx.x == 0 && blockIdx.x == 0) {
	   print ("\n");
           print (gV);
	   print ("\n");
	   print (gK);
	   print ("\n");
	   print ("\n");
           print (sV);
	   print ("\n");
	   print (sK);
	   print ("\n");
           print (gmemLayoutVFp8);
	   print ("\n");
           print (gmemLayoutVFp16);
	}

        // Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
        //     mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        // Tensor gK = seqlen_traits_k.get_local_tile_tensor(
        //     mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        // Tensor gV = seqlen_traits_k.get_local_tile_tensor(
        //     mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
Tri Dao's avatar
Tri Dao committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
        auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},
                                          group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));  // (TMA), (TMA)
        auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sK), group_modes<0, 2>(gK));  // (TMA, k), (TMA, PIPE)
        auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout<ClusterShape>{},
                                          group_modes<0, 2>(sV), group_modes<0, 2>(gV));  // (TMA, k), (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v<GmemTiledCopyKV, SM90_TMA_LOAD_MULTICAST>) {
            auto block_layout = Layout<ClusterShape>{}; // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

295
        int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
Tri Dao's avatar
Tri Dao committed
296
297
298
299
300
301
302
303
304
305
306
        int n_block = n_block_max - 1;

        int lane_predicate = cute::elect_one_sync();
        if (lane_predicate) {
            pipeline_k.producer_acquire(smem_pipe_write_k);
            copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
            ++smem_pipe_write_k;
        }

        // Wait for the MMA warpgroups to say that smem_q is ready
Tri Dao's avatar
Tri Dao committed
307
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
Tri Dao's avatar
Tri Dao committed
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332

        if (lane_predicate) {
            shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
            copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
        }

        // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem
        // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the
        // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O.
        shared_storage.barrier_O.wait((work_idx + 1) % 2);

        if (lane_predicate) {
            // CUTLASS_PRAGMA_NO_UNROLL
            #pragma unroll 2
            for (; n_block > 0; --n_block) {
                pipeline_k.producer_acquire(smem_pipe_write_k);
                copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
                    tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
                ++smem_pipe_write_k;
                pipeline_v.producer_acquire(smem_pipe_write_v);
                copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                    tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
                ++smem_pipe_write_v;
            }
        }
Tri Dao's avatar
Tri Dao committed
333
        scheduler.prefetch_next_work(scheduler_params, work_tile_info);
Tri Dao's avatar
Tri Dao committed
334
335
336
337
338
339
        if (lane_predicate) {
            pipeline_v.producer_acquire(smem_pipe_write_v);
            copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
            ++smem_pipe_write_v;
        }
Tri Dao's avatar
Tri Dao committed
340
        scheduler.broadcast_next_work(work_tile_info);
Tri Dao's avatar
Tri Dao committed
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
    }

    /// Producer epilogue that keeps a block in the cluster from exiting early.
    ///
    /// The elected lane waits in producer_tail on both pipelines. Per the original notes,
    /// each stage is then either released by all consumers, or is acquired immediately if it
    /// was never used, because its phase is still inverted from make_producer_start_state.
    CUTLASS_DEVICE void
    load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
              PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) {
        if (!cute::elect_one_sync()) { return; }
        pipeline_k.producer_tail(smem_pipe_write_k);
        pipeline_v.producer_tail(smem_pipe_write_v);
    }

    /// Block the calling MMA warpgroup on its own scheduler named barrier until
    /// NumMmaThreads threads have reached it.
    /// The barrier id is WarpSchedulerWG1 - 1 + canonical_warp_group_idx().
    /// No-op when UseSchedulerBarrier is false.
    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
        }
    }

    /// Arrive on the scheduler barrier(s) of the other MMA warpgroup(s) so they can pass
    /// warp_scheduler_barrier_sync.
    ///
    /// Warpgroup indices come from cutlass::canonical_warp_group_idx(); barrier ids are
    /// WarpSchedulerWG1 - 1 + index.
    ///   - 2 MMA warpgroups (1, 2): warpgroup k arrives on warpgroup 3 - k's barrier.
    ///   - 3 MMA warpgroups (1, 2, 3): warpgroup k arrives on the next two warpgroups
    ///     round-robin (1 -> 2, 3; 2 -> 3, 1; 3 -> 1, 2).
    /// No-op when UseSchedulerBarrier is false.
    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
        } else {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3)  /*id*/);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3)  /*id*/);
        }
    }

    /// One-time setup for the MMA warpgroups.
    ///
    /// Arrives on QueryEmpty so the producer may load the first Q. When scheduler barriers
    /// are enabled, it also pre-arrives on them so the warpgroups issue their first GEMM in
    /// order: every warpgroup with index > 1 pre-arrives on warpgroup 1's barrier, and with
    /// 3 MMA warpgroups, warpgroup 3 also pre-arrives on warpgroup 2's barrier.
    CUTLASS_DEVICE void
    mma_init() {
        // Tell producer (warp 0) that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        if constexpr (!UseSchedulerBarrier) { return; }
        static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
        if (cutlass::canonical_warp_group_idx() > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
        }
        if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
            if (cutlass::canonical_warp_group_idx() > 2) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
            }
        }

    }

    /// Consumer side of the mainloop for one work tile.
    ///
    /// Walks n_block from n_block_count - 1 down to 0. For each block it computes
    /// S = Q K^T (GEMM-I), applies the seqlen/causal masks, runs the online softmax, and
    /// accumulates O += P V (GEMM-II) into tOrO. The loops are software-pipelined: GEMM-I
    /// for the next block is issued before GEMM-II for the current P, matching load(), which
    /// issues K(n_block - 1) ahead of V(n_block). Each K/V stage is released once the GEMM
    /// reading it has completed.
    /// On return, tOrO has been rescaled by the factors from softmax.finalize.
    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
    CUTLASS_DEVICE void
    mma(Params const& mainloop_params,
        MainloopPipeline pipeline_k,
        MainloopPipeline pipeline_v,
        PipelineState& smem_pipe_read_k,
        PipelineState& smem_pipe_read_v,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int n_block_count,
        int thread_idx,
        int work_idx,
        int m_block,
        SharedStorage& shared_storage,
        const Seqlen_traits& seqlen_traits_q,
        const Seqlen_traits& seqlen_traits_k
        ) {
        static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
        Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

        typename Ktraits::TiledMma0 tiled_mma0;
        typename Ktraits::TiledMma1 tiled_mma1;
        auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
        auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors" for first matmul.
        Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
        Tensor tSrK = threadMma0.partition_fragment_B(sK);
        // Allocate "fragments/descriptors" for second matmul.
        // Note: S becomes P.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);

        // Dummy sS to just get the shape correctly for GEMM-II.
        Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{});
        Tensor tOrS = threadMma1.partition_fragment_A(sS);
        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        ReorgCFp8toAFp8 reg2reg;  // only applied for FP8 (see the is_same_v checks below)
        // Register layout that reinterprets the GEMM-I accumulator as the GEMM-II A operand.
        auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        // The first P*V GEMM must overwrite tOrO rather than accumulate into it.
        // Presumably flash::gemm switches accumulate_ back after its first k-block; confirm in utils.h.
        tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
        int const seqlen_q = seqlen_traits_q.actual_seq_len;
        int const seqlen_k = seqlen_traits_k.actual_seq_len;
        int n_block = n_block_count - 1;

        // Wait for this tile's Q (barrier_Q phase alternates with work_idx).
        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        // First block: GEMM-I only.
        consumer_wait(pipeline_k, smem_pipe_read_k);
        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
        warp_scheduler_barrier_arrive();
        // From the second work tile onward, the last warp waits for the previous tile's O TMA
        // store to finish. It then arrives on barrier_O in every CTA of the cluster; load()
        // waits on that barrier before overwriting smem_v.
        if (work_idx != 0) {
            int lane_predicate = cute::elect_one_sync();
            if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
                tma_store_wait<0>();
                #pragma unroll
                for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
                    shared_storage.barrier_O.arrive(cta_id, lane_predicate);
                }
            }
        }
        warpgroup_wait<0>();
        pipeline_k.consumer_release(smem_pipe_read_k);
        ++smem_pipe_read_k;

        // Exclusive upper bound on the visible column (within block n_block) for a given row,
        // with the causal diagonal offset by seqlen_k - seqlen_q.
        auto col_limit_causal = [&](int row, int n_block) {
            return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
        };
        {
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if constexpr (!Is_causal) {  // Just masking based on col
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
                } else {  // mask based on both row and col
                    // using std::min is faster than doing col >= limit0 or col >= limit1
                    // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
                    // right hand side can be negative and might be converted to a very large unsigned integer.
                    if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
                                                        col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                        tSrS(i) = -INFINITY;
                    }
                }
            }
        }

        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
        auto tSrSPrec = convert_type<Element>(tSrS);
        if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
            reg2reg(tSrSPrec);
        }
        // tOrP views tSrSPrec's registers; later iterations copy the new P into it.
        Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout);

        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);

        constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
        // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
        #pragma unroll
        for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
            // GEMM-I computes S for block n_block - 1 while GEMM-II consumes P from block n_block.
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
            Tensor tScS = threadMma0.partition_C(cS);
            #pragma unroll
            for (int i = 0; i < size(tSrS); ++i) {
                if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) {
                    tSrS(i) = -INFINITY;
                }
            }
            cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
            softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
        }

        // Remaining blocks need no masking.
        #pragma unroll 1
        for (; n_block > 0; --n_block) {
            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read_k);
            warp_scheduler_barrier_sync();
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
            softmax.rescale_o(tOrO, scores_scale);
            consumer_wait(pipeline_v, smem_pipe_read_v);
            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
            warp_scheduler_barrier_arrive();
            warpgroup_wait<1>();
            pipeline_k.consumer_release(smem_pipe_read_k);  // release K
            cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
            softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
        }
        // All GEMM-I reads of smem_q have completed: tell the producer warp that smem_q can be refilled.
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
        // Final GEMM-II with the last P.
        softmax.rescale_o(tOrO, scores_scale);
        consumer_wait(pipeline_v, smem_pipe_read_v);
        flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
        cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
        warpgroup_wait<0>();
        pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
        ++smem_pipe_read_v;

        softmax.rescale_o(tOrO, scores_scale);
        return;
    }

};

} // namespace flash