"llm/llama.go" did not exist on "7c71c10d4fc51f2a26961d902f3ed660af789c93"
fmha_fprop_fp16_kernel.sm80.cu 12.4 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

Tri Dao's avatar
Tri Dao committed
28
29
30
31
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include "static_switch.h"
32
#include "fp16_switch.h"
Tri Dao's avatar
Tri Dao committed
33
34
35
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

Tri Dao's avatar
Tri Dao committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Heuristic for how many ways to split each (batch, head) pair across CTAs.
//
// Launching batch_nheads * s CTAs onto num_SMs * ctas_per_sm concurrent CTA
// slots takes waves = batch_nheads * s / slots waves; the "wave efficiency"
// waves / ceil(waves) is the fraction of the launched waves that is actually
// filled. Example: batch_nheads = 48 on 108 SMs with 1 CTA/SM gives 0.89 for
// 2 splits but only 0.67 for 3 splits, so 2 is better.
// More splits also mean more HBM reads/writes, so rather than taking the
// argmax, return the smallest s in [1, max_splits] whose efficiency is above
// 95% of the best efficiency seen. Returns 1 when max_splits is 0 or when no
// candidate qualifies (degenerate inputs such as zero work or zero slots
// produce NaN efficiencies and fall through to 1).
// max_splits must be non-negative: it is passed to std::vector::reserve.
int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
    const int total_slots = num_SMs * ctas_per_sm;

    // wave_eff[i] holds the efficiency of using (i + 1) splits.
    std::vector<float> wave_eff;
    wave_eff.reserve(max_splits);
    float best = 0.f;
    for (int s = 1; s <= max_splits; ++s) {
        const float waves = float(batch_nheads * s) / total_slots;
        const float e = waves / ceil(waves);
        wave_eff.push_back(e);
        best = (e > best) ? e : best;
    }

    // Prefer the fewest splits that get within 5% of the best efficiency.
    const double threshold = 0.95 * best;
    for (size_t i = 0; i < wave_eff.size(); ++i) {
        if (wave_eff[i] > threshold) {
            return int(i) + 1;
        }
    }
    return 1;
}

Tri Dao's avatar
Tri Dao committed
62
// Forward-pass FMHA kernel entry point. It is a thin __global__ shim: every
// compile-time flag is forwarded unchanged to fmha::device_1xN_loop, which
// does the actual work.
//
// The launcher in this file (run_fmha_fp16_sm80_loop_) launches it with
// grid (b, h, num_splits), Kernel_traits::THREADS threads per block and the
// dynamic shared-memory size returned by fmha::get_dynamic_smem_size (plus the
// softmax-LSE tile when looping over more than one column block).
// Despite "fp16" in the name, it is also instantiated with bf16 traits via
// FP16_SWITCH.
template<typename Traits, bool kIsDropout, bool kIsCausal, bool kReturnSoftmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
    fmha::device_1xN_loop<Traits, kIsDropout, kIsCausal, kReturnSoftmax>(params);
}

// Host-side launcher for fmha_fprop_fp16_sm80_loop_kernel with a fixed
// Kernel_traits.
//
// configure == true : nothing is launched. Only launch_params.elts_per_thread
//                     is filled in (STEPS * MMAS_M * MMAS_N * 8 * loop_steps),
//                     then the function returns.
// configure == false: selects the kernel instantiation matching
//                     (is_dropout, is_causal, return_softmax). It opts in to
//                     >= 48 KB dynamic shared memory when needed and picks
//                     params.num_splits automatically if the caller left it
//                     <= 0 (note: this writes back into launch_params). It
//                     then launches on launch_params.stream with grid
//                     (b, h, num_splits).
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                              const bool configure) {
    // Number of blocksize_c-wide column tiles needed to cover seqlen_k.
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;

    if (configure) {
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int M = Kernel_traits::Cta_tile_p::M;
        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
        launch_params.elts_per_thread = elts_per_head;
        return;
    }

    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    // Don't need smem_size_softmax_lse if we're not looping
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
        + (loop_steps > 1 ? smem_size_softmax_lse : 0);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
        auto kernel = launch_params.params.is_causal
            ? (launch_params.return_softmax
               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
            : (launch_params.return_softmax
               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
        // Dynamic shared memory of 48 KB or more requires an explicit per-kernel opt-in.
        if( smem_size >= 48 * 1024 ) {
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }

        // Automatically set num_splits to maximize occupancy
        if (launch_params.params.num_splits <= 0) {
            // Checked: if the occupancy query failed, ctas_per_sm would
            // otherwise be fed to the heuristic without ever being written.
            int ctas_per_sm = 0;
            FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
            auto dprops = at::cuda::getCurrentDeviceProperties();
            constexpr int M = Kernel_traits::Cta_tile_p::M;
            // Cap the split count at the number of M-row tiles of the query
            // sequence (ceil-div; the parentheses matter: `seqlen_q + M - 1 / M`
            // would just be seqlen_q + M) and at 30.
            const int num_m_blocks = (launch_params.params.seqlen_q + M - 1) / M;
            launch_params.params.num_splits = num_splits_heuristic_fwd(
                launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
                ctas_per_sm,
                /*max_splits=*/std::min(30, num_m_blocks)
            );
        }
        dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
            launch_params.params);
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
    });
}

Tri Dao's avatar
Tri Dao committed
125
// Forward-pass dispatcher. It picks an FMHA_kernel_traits instantiation from
// the head dimension (params.d), the key length (params.seqlen_k) and, for
// d == 128, the compute capability, then forwards to run_fmha_fp16_sm80_loop_.
// FP16_SWITCH supplies `elem_type` from params.is_bf16, so despite the name
// this path also covers bf16.
//
// NOTE(review): some combinations dispatch nothing at all:
//   - d outside {16, 32, 64, 128};
//   - d in {32, 64} with seqlen_k neither 128 nor >= 256.
// In those cases configure mode also leaves elts_per_thread untouched.
// Presumably the caller validates d and rounds seqlen_k beforehand — confirm.
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                        const bool configure) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        // Queried on every call, regardless of which branch needs it.
        auto dprops = at::cuda::getCurrentDeviceProperties();
        const auto seqlen_k = launch_params.params.seqlen_k;

        switch (launch_params.params.d) {
        case 16:
            if (seqlen_k == 128) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            } else {
                // Every other seqlen_k (256 included) uses the 256-wide tile.
                // TD [2022-05-15] The 512-wide tile
                // (FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>)
                // gives wrong results rn.
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            }
            break;

        case 32:
            if (seqlen_k == 128) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            } else if (seqlen_k >= 256) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            }
            break;

        case 64:
            if (seqlen_k == 128) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            } else if (seqlen_k >= 256) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            }
            break;

        case 128:
            if (seqlen_k == 128) {
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            } else if (dprops->major == 8 && dprops->minor == 0) {
                // TD [2022-06-05] Keep K in registers to reduce register spilling
                // Gives about 6% speedup compared to using block size 128.
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>>(
                    launch_params, configure);
            } else {
                // Need to use the same block size as backward
                run_fmha_fp16_sm80_loop_<FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>>(
                    launch_params, configure);
            }
            break;

        default:
            break;
        }
    });
}