"sgl-kernel/csrc/gemm/per_token_quant_fp8.cu" did not exist on "b3251e9f40b85159d52563b9ca8276fa0fa03703"
fmha_fprop_fp16_kernel.sm80.cu 10.6 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

Tri Dao's avatar
Tri Dao committed
28
29
30
31
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include "static_switch.h"
32
#include "fp16_switch.h"
Tri Dao's avatar
Tri Dao committed
33
34
35
36
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

// Forward-pass kernel entry point for the sm80 fp16/bf16 path.
//
// Dropout, causal masking and softmax return are all compile-time template flags, so
// every combination is its own specialization. The body only forwards the parameter
// block to fmha::device_1xN_loop with the same flags. run_fmha_fp16_sm80_loop_ launches
// it on a (params.b, params.h) grid with Kernel_traits::THREADS threads per block.
template <typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
    fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}

// Host-side launcher for one Kernel_traits specialization of the forward kernel.
//
// configure == true : sizing pass only. Stores into launch_params.elts_per_thread the
//                     product of query row blocks, MMA tile counts (MMAS_M * MMAS_N * 8)
//                     and key column blocks, then returns without launching anything.
// configure == false: selects the kernel specialization matching is_dropout, is_causal
//                     and return_softmax and launches it. The launch uses a
//                     (params.b, params.h) grid, Kernel_traits::THREADS threads and
//                     launch_params.stream, then peeks for a launch error.
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                              const bool configure) {
    const FMHA_fprop_params &fprop = launch_params.params;

    // Number of Cta_tile_p::N-wide key blocks needed to cover seqlen_k (ceil-div).
    constexpr int kBlockN = Kernel_traits::Cta_tile_p::N;
    const int num_key_blocks = (fprop.seqlen_k + kBlockN - 1) / kBlockN;

    if (configure) {
        // Sizing pass: derive elts_per_thread from the tile counts and stop here.
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int kBlockM = Kernel_traits::Cta_tile_p::M;
        constexpr size_t kMmasM = Mma_tile_p::MMAS_M;
        constexpr size_t kMmasN = Mma_tile_p::MMAS_N;
        const size_t num_query_blocks = (fprop.seqlen_q + kBlockM - 1) / kBlockM;
        launch_params.elts_per_thread = num_query_blocks * kMmasM * kMmasN * 8 * num_key_blocks;
        return;
    }

    // The extra Smem_dp_sum tile is only needed when the kernel loops over more than
    // one key block.
    int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
    if (num_key_blocks > 1) {
        smem_size += Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    }

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    // Only is_dropout goes through BOOL_SWITCH; is_causal and return_softmax are
    // resolved with ordinary branches that pick a kernel pointer.
    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
        void (*kernel)(FMHA_fprop_params) = nullptr;
        if (fprop.is_causal) {
            kernel = launch_params.return_softmax
                ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
                : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>;
        } else {
            kernel = launch_params.return_softmax
                ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
                : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>;
        }

        // Opt in to large dynamic shared memory (at or above 48 KB) for this kernel.
        if (smem_size >= 48 * 1024) {
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }

        const dim3 grid(fprop.b, fprop.h);
        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(fprop);
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
    });
}

Tri Dao's avatar
Tri Dao committed
85
// Chooses the key-block size (first FMHA_kernel_traits argument) for head dims
// 16 / 32 / 64 from seqlen_k alone, then forwards to run_fmha_fp16_sm80_loop_.
//   - seqlen_k == 128: a 128 block.
//   - every other seqlen_k: a 256 block. run_fmha_fp16_sm80_loop_ then loops over
//     ceil(seqlen_k / 256) blocks.
// Every seqlen_k value therefore dispatches a kernel.
//
// Template parameters:
//   kHeadDim - head dimension (second FMHA_kernel_traits argument).
//   Elem     - element type chosen by FP16_SWITCH (fp16 or bf16, per params.is_bf16).
template<int kHeadDim, typename Elem>
static void run_fmha_fp16_sm80_by_seqlen_k_(Launch_params<FMHA_fprop_params> &launch_params,
                                            const bool configure) {
    if (launch_params.params.seqlen_k == 128) {
        using Kernel_traits = FMHA_kernel_traits<128, kHeadDim, 16, 1, 4, 0x08u, Elem>;
        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    } else {
        // TD [2022-05-15] A 512 block (tried for d == 16) currently gives wrong results,
        // so 256 is the largest block used here.
        using Kernel_traits = FMHA_kernel_traits<256, kHeadDim, 16, 1, 4, 0x08u, Elem>;
        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    }
}

// Host-side forward dispatcher for the sm80 kernels (fp16 or bf16, per params.is_bf16).
// Picks FMHA_kernel_traits from the head dimension params.d, from seqlen_k and, for
// d == 128, from the current device's compute capability. It then calls
// run_fmha_fp16_sm80_loop_, which either only fills launch_params.elts_per_thread
// (configure == true) or launches the kernel (configure == false).
//
// Supported head dims: 16, 32, 64, 128. For any other d nothing is dispatched and
// launch_params is left unchanged. NOTE(review): callers are presumably expected to
// reject other head sizes before reaching this point; confirm.
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                        const bool configure) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        if (launch_params.params.d == 16) {
            run_fmha_fp16_sm80_by_seqlen_k_<16, elem_type>(launch_params, configure);
        } else if (launch_params.params.d == 32) {
            run_fmha_fp16_sm80_by_seqlen_k_<32, elem_type>(launch_params, configure);
        } else if (launch_params.params.d == 64) {
            // Same rule as d == 16 / 32: any seqlen_k other than 128 uses the 256 block.
            // seqlen_k < 128 and 128 < seqlen_k < 256 therefore still dispatch a kernel
            // and still set elts_per_thread in the configure pass.
            run_fmha_fp16_sm80_by_seqlen_k_<64, elem_type>(launch_params, configure);
        } else if (launch_params.params.d == 128) {
            // Device properties are only consulted for this head dimension.
            auto dprops = at::cuda::getCurrentDeviceProperties();
            const bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
            if (launch_params.params.seqlen_k != 128 && is_sm80) {
                // TD [2022-06-05] Keep K in registers to reduce register spilling
                // Gives about 6% speedup compared to using block size 128.
                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                // Block size 128 is used when seqlen_k == 128, or when the device is not
                // sm_80. In the non-sm_80 case, forward needs to use the same block size
                // as backward.
                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        }
    });
}