/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/collective_builder.hpp"

#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

// template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename Element_>
20
template <typename Ktraits, typename Seqlen_traits>
Tri Dao's avatar
Tri Dao committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
struct CollectiveEpilogueFwd {

    using Element = typename Ktraits::Element;
    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
    using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;

    static constexpr int kNWarps = Ktraits::kNWarps;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
    static constexpr bool Is_WS = kNWarps >= 12;

    static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumCopyThreads;

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
    using SharedStorage = cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>>;

43
    using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
Tri Dao's avatar
Tri Dao committed
44
45
    using TMA_O = decltype(make_tma_copy(
        GmemTiledCopyOTMA{},
46
47
48
49
50
        make_tensor(
            make_gmem_ptr(static_cast<Element*>(nullptr)), 
            typename Seqlen_traits::ShapeT{}, 
            typename Seqlen_traits::StrideT{}
        ),
Tri Dao's avatar
Tri Dao committed
51
52
53
54
        SmemLayoutO{},
        select<0, 2>(TileShape_MNK{}),
        _1{}));  // no mcast for O

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    // These are for storing the output tensor without TMA (e.g., for setting output to zero and var-seq-len)
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
    static_assert(kHeadDim % kNumVecElem == 0);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
    using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int<kNumVecElem>{}),
        LayoutRight{}));
    using TiledCopyO = decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{}, // Thr layout
        TiledCopyOValLayout{} // Val layout
    ));

Tri Dao's avatar
Tri Dao committed
74
75
76
    // Host side kernel arguments
    struct Arguments {
        Element* ptr_O;
77
        typename Seqlen_traits::LayoutT const layout_O;
Tri Dao's avatar
Tri Dao committed
78
        float* ptr_LSE;
79
        typename Seqlen_traits::LayoutLseT const layout_LSE;
Tri Dao's avatar
Tri Dao committed
80
81
82
83
84
    };

    // Device side kernel params
    struct Params {
        Element* ptr_O;
85
        typename Seqlen_traits::LayoutT const layout_O;
Tri Dao's avatar
Tri Dao committed
86
        float* ptr_LSE;
87
        typename Seqlen_traits::LayoutLseT const layout_LSE;
Tri Dao's avatar
Tri Dao committed
88
89
90
91
92
        TMA_O tma_store_O;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
93
        Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.layout_O);
Tri Dao's avatar
Tri Dao committed
94
95
96
97
98
99
        TMA_O tma_store_O = make_tma_copy(
            GmemTiledCopyOTMA{},
            mO,
            SmemLayoutO{},
            select<0, 2>(TileShape_MNK{}),
            _1{}); // no mcast for O
100
        return {args.ptr_O, args.layout_O, args.ptr_LSE, args.layout_LSE, tma_store_O};
Tri Dao's avatar
Tri Dao committed
101
102
103
104
105
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& epilogue_params) {
106
107
108
        if constexpr (!Seqlen_traits::kUseVarSeqLen) {
            cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor());
        }
Tri Dao's avatar
Tri Dao committed
109
110
111
112
113
114
115
116
117
118
    }

    template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
    CUTLASS_DEVICE void
    store(Params const& epilogue_params,
          FrgTensorO const& tOrO,
          FrgTensorLSE const& lse,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
119
120
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
Tri Dao's avatar
Tri Dao committed
121
122
123
124
125
126
127
128
129
130
131
132
          ) {

        auto [m_block, bidh, bidb] = block_coord;
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
        auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
        auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOrO_out = flash::convert_type<Element>(tOrO);
        Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);        // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Make sure all WGs have finished reading V
Tri Dao's avatar
Tri Dao committed
133
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
Tri Dao's avatar
Tri Dao committed
134
135
136
137
138
        cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
        cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
                                            cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

139
140
141
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);
Tri Dao's avatar
Tri Dao committed
142
143
144
145
146
147
148
149
150
151
152
153
        Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
        static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
        static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
        // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices.
        Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{});
        CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
        if (get<1>(taccOcO_row(_0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO_row(mi));
154
                if (row < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(row) = lse(mi); }
Tri Dao's avatar
Tri Dao committed
155
156
157
            }
        }

158
159
160
161
162
163
        int write_warp_idx = kNWarps - 1;
        if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
            cutlass::arch::NamedBarrier::sync(
                NumMmaThreads + cutlass::NumThreadsPerWarp, 
                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
            );
Tri Dao's avatar
Tri Dao committed
164
        }
165
166
167
168
169
170
        TiledCopyO gmem_tiled_copy_O;
        flash::write_O<!Seqlen_traits::kUseVarSeqLen, NumCopyThreads>(
            epilogue_params.ptr_O, epilogue_params.tma_store_O, gmem_tiled_copy_O, 
            epilogue_params.layout_O, select<0, 2>(TileShape_MNK{}), sO, 
            m_block, bidh, bidb, seqlen_traits_q, write_warp_idx
        );
Tri Dao's avatar
Tri Dao committed
171
172
173
174
175
176
177
178
    }

    CUTLASS_DEVICE void
    store_tail() {
        tma_store_wait<0>();
    }

    // Write 0 to output and -inf to LSE
179
    template<typename SharedStorage>
Tri Dao's avatar
Tri Dao committed
180
181
    CUTLASS_DEVICE void
    store_zero(
182
183
184
185
186
187
          Params const& epilogue_params,
          SharedStorage& shared_storage,
          int thread_idx,
          cute::tuple<int32_t, int32_t, int32_t> const& block_coord,
          const Seqlen_traits& seqlen_traits_q
          ) {
Tri Dao's avatar
Tri Dao committed
188
        auto [m_block, bidh, bidb] = block_coord;
189
190
191
192
193
194
195
196
197
        Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.layout_O);
        Tensor gO = seqlen_traits_q.get_local_tile_tensor(
            mO, select<0, 2>(TileShape_MNK{}), bidh, bidb
        )(_, _, m_block);  // (M, K)
        Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), epilogue_params.layout_LSE);
        Tensor gLSE = seqlen_traits_q.get_lse_local_tile_tensor(
            mLSE, Shape<Int<kBlockM>>{}, bidh, bidb)(_, m_block);

        TiledCopyO gmem_tiled_copy_O;
Tri Dao's avatar
Tri Dao committed
198
199
200
201
202
203
204
205
206
207
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_fragment_like(tOgO);
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        #pragma unroll
208
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
Tri Dao's avatar
Tri Dao committed
209
210
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
211
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
Tri Dao's avatar
Tri Dao committed
212
213
        );
        static_assert(kBlockM <= NumMmaThreads);
214
        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
Tri Dao's avatar
Tri Dao committed
215
216
217
218
219
    }

};

} // namespace flash