/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"

#include "utils.h"

namespace flash {

using namespace cute;

// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
20
template <typename Ktraits, typename Seqlen_traits>
Tri Dao's avatar
Tri Dao committed
21
22
struct CollectiveEpilogueFwd {

23
    using Element = typename Ktraits::OutputType;    
Tri Dao's avatar
Tri Dao committed
24
25
26
27
28
29
30
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kNWarps = Ktraits::kNWarps;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
31
    static constexpr bool Is_WS = kNWarps >= 12;    
Tri Dao's avatar
Tri Dao committed
32
33
34
35
36
37
38
39
40
41
42

    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
    using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;

43
    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
Tri Dao's avatar
Tri Dao committed
44
45
    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
46
47
48
49
50
        make_tensor(
            make_gmem_ptr(static_cast<Element*>(nullptr)), 
            typename Seqlen_traits::ShapeT{}, 
            typename Seqlen_traits::StrideT{}
        ),
Tri Dao's avatar
Tri Dao committed
51
52
53
54
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
    static_assert(kHeadDim % kNumVecElem == 0);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyO = decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{}, // Thr layout
        TiledCopyOValLayout{} // Val layout
    ));

74
75
76
77
78
79
80
81
82
83
    // used for rmem -> smem O copy in fp8 kernel to undo column permutation
    using ThreadLayoutrO = Layout<Shape<_8, Int<kBlockM/16>, _4, _1>,
                                 Stride<_4, _32, _1, _0>>;
    using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
                                Stride<_0, _2, Stride<_4, _1>, _8>>;
    using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint16_t>, Element>{},
                      ThreadLayoutrO{}, ValueLayoutrO{}));
    using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
    using SmemLayoutrO = decltype(composition(SmemLayoutO{}, Layout<TiledCopyShaperO>{}));

Tri Dao's avatar
Tri Dao committed
84
85
86
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
87
        typename Seqlen_traits::LayoutT const layout_O;
Tri Dao's avatar
Tri Dao committed
88
        float* ptr_LSE;
89
        typename Seqlen_traits::LayoutLseT const layout_LSE;
Tri Dao's avatar
Tri Dao committed
90
91
92
93
94
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
95
        typename Seqlen_traits::LayoutT const layout_O;
Tri Dao's avatar
Tri Dao committed
96
        float* ptr_LSE;
97
        typename Seqlen_traits::LayoutLseT const layout_LSE;
Tri Dao's avatar
Tri Dao committed
98
99
100
101
102
        TMA_O tma_store_O;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
103
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
Tri Dao's avatar
Tri Dao committed
104
105
106
107
108
109
        TMA_O tma_store_O = make_tma_copy(
            GmemTiledCopyOTMA{},
            mO,
            SmemLayoutO{},
            select<0, 2>(TileShape_MNK{}),
            _1{}); // no mcast for O
110
        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
Tri Dao's avatar
Tri Dao committed
111
112
113
114
115
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& epilogue_params) {
116
117
118
        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
        }
Tri Dao's avatar
Tri Dao committed
119
120
121
122
123
124
125
126
127
128
    }

    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& epilogue_params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
129
130
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
Tri Dao's avatar
Tri Dao committed
131
132
133
134
135
136
137
138
139
140
141
142
          ) {

        auto [m_block, bidh, bidb] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOrO_out = flash::convert_type<Element>(tOrO);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Make sure all WGs have finished reading V
Tri Dao's avatar
Tri Dao committed
143
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
Tri Dao's avatar
Tri Dao committed
144
145
146
147
148
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

149
150
151
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
Tri Dao's avatar
Tri Dao committed
152
153
154
155
156
157
158
159
160
161
162
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
163
                const int row = get<0>(taccOcO_row(mi));                
164
                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
Tri Dao's avatar
Tri Dao committed
165
166
167
            }
        }

168
169
170
171
172
173
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(
                NumMmaThreads + cutlass::NumThreadsPerWarp, 
                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
            );
Tri Dao's avatar
Tri Dao committed
174
        }
175
176
177
178
179
180
        TiledCopyO gmem_tiled_copy_O;
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, 
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, 
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
Tri Dao's avatar
Tri Dao committed
181
182
    }

183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store_fp8(Params const& epilogue_params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
          ) {
        // using SmemLayoutrO = typename Ktraits::SmemLayoutrO;
        // using TiledCopyrO = typename Ktraits::TiledCopyrO;
        auto [m_block, bidh, bidb] = block_coord;        

        TiledCopyrO rmem_tiled_copy_O;
        Tensor sOacc = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutrO{});
        auto rmem_thr_copy_O = rmem_tiled_copy_O.get_thread_slice(thread_idx);
        
        Tensor taccOsO = rmem_thr_copy_O.partition_D(sOacc);
        Tensor tOrO_out = flash::convert_type<Element>(tOrO); // Element is Ktraits::OutputType
        Tensor taccOrO = make_tensor(tOrO_out.data(), shape(taccOsO));

        // Make sure all WGs have finished reading V
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);        
        cute::copy(rmem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M        
        int const seqlen_q = [&] {
            if constexpr(Seqlen_traits::kUseVarSeqLen) { return seqlen_traits_q.actual_seq_len; }
            else { return shape<2>(epilogue_params.layout_LSE); }
        }();        
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
                if (row < seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
            }
        }
        
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                              cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        }
        TiledCopyO gmem_tiled_copy_O;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, 
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, 
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
    }

Tri Dao's avatar
Tri Dao committed
250
251
252
253
254
255
    CUTLASS_DEVICE void
    store_tail() {
        tma_store_wait<0>();
    }

    // Write 0 to output and -inf to LSE
256
    template<typename SharedStorage>
Tri Dao's avatar
Tri Dao committed
257
258
    CUTLASS_DEVICE void
    store_zero(
259
260
261
262
263
264
          Params const& epilogue_params,
          SharedStorage& shared_storage,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
          ) {
Tri Dao's avatar
Tri Dao committed
265
        auto [m_block, bidh, bidb] = block_coord;
266
267
268
269
270
271
272
273
274
        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
        )(_, _, m_block);  // (M, K)
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);

        TiledCopyO gmem_tiled_copy_O;
Tri Dao's avatar
Tri Dao committed
275
276
277
278
279
280
281
282
283
284
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_fragment_like(tOgO);
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
285
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
Tri Dao's avatar
Tri Dao committed
286
287
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
288
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
Tri Dao's avatar
Tri Dao committed
289
290
        );
        static_assert(kBlockM <= NumMmaThreads);
291
        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
Tri Dao's avatar
Tri Dao committed
292
293
294
295
296
    }

};

} // namespace flash