fmha_fprop_fp16_kernel.sm80.cu

/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"

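// This translation unit wraps the fp16 forward-pass FMHA kernel for SM80:
// a thin __global__ entry point, a launcher that picks the compile-time
// variant (dropout / causal / return-softmax) and sizes shared memory, and a
// dispatcher that selects Kernel_traits from the head dimension d and the
// sequence length s.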
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}

template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                            const bool configure) {
    bool is_causal = launch_params.params.is_causal;
    // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
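    // The three runtime flags map onto the <Is_dropout, Is_causal, Return_softmax>
    // template parameters of the kernel above, so each combination gets its own
    // instantiation and the branches stay out of the device code.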
    auto kernel = launch_params.is_dropout
        ? (is_causal
           ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
           : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
        : (is_causal
           ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
           : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));

    constexpr int N = Kernel_traits::Cta_tile_p::N;
    const int loop_steps = (launch_params.params.s + N - 1) / N;
    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    // Don't need smem_size_softmax_lse if we're not looping
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
        + (loop_steps > 1 ? smem_size_softmax_lse : 0);

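    // Requesting more than the default 48 KB of dynamic shared memory requires
    // an explicit per-kernel opt-in via cudaFuncSetAttribute.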
    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

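    // A configure-only call just computes how many softmax elements each thread
    // produces (stored in launch_params.elts_per_thread) and returns without
    // launching; the caller can use this e.g. for dropout RNG bookkeeping.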
    if (configure) {
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int M = Kernel_traits::Cta_tile_p::M;
        size_t STEPS = (launch_params.params.s + M - 1) / M;
        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
        launch_params.elts_per_thread = elts_per_head;
        return;
    }

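    // One CTA per (head, batch) pair; the sequence dimension is covered by the
    // loop inside device_1xN_loop.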
    dim3 grid(launch_params.params.h, launch_params.params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
        launch_params.params);

    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

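// Dispatch on head dimension d and sequence length s. From the instantiations
// below, the first two FMHA_kernel_traits parameters track the sequence-length
// tile and the head dimension; the remaining ones appear to fix the tile step,
// warp layout, and flags per configuration, with the 0x18u flag variant used
// for the "keep K in registers" case noted in the d == 128 branch.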
void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
                        const bool configure) {
    if (launch_params.params.d == 16) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s == 256 ) {
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            // TD [2022-05-15] 512 gives wrong results rn
            // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>;
            using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 32) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s == 256 ) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        }
    } else if (launch_params.params.d == 64) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else if( launch_params.params.s >= 256 ) {
            auto dprops = at::cuda::getCurrentDeviceProperties();
            if (dprops->major == 8 && dprops->minor >= 0) {
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if (dprops->major == 7 && dprops->minor == 5) {
                if (launch_params.is_dropout) { // Need to use the same block size as backward
                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                } else {
                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                }
            }
        }
    } else if (launch_params.params.d == 128) {
        if( launch_params.params.s == 128 ) {
            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        } else {
            auto dprops = at::cuda::getCurrentDeviceProperties();
            if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
                // TD [2022-06-05] Keep K in registers to reduce register spilling
                // Gives about 6% speedup compared to using block size 128.
                using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {  // Need to use the same block size as backward
                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        }
    }
    // if (launch_params.params.d == 64) {
    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    // }
    // if (launch_params.params.d == 64) {
    //     if( launch_params.params.s == 128 ) {
    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //     } else if( launch_params.params.s >= 256 ) {
    //         auto dprops = at::cuda::getCurrentDeviceProperties();
    //         if (dprops->major == 8 && dprops->minor >= 0) {
    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //         } else if (dprops->major == 7 && dprops->minor == 5) {
    //             if (launch_params.is_dropout) { // Need to use the same block size as backward
    //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //             } else {
    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
    //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //             }
    //         }
    //     }
    // }
    // if (launch_params.params.d == 128) {
    //     if( launch_params.params.s == 128 ) {
    //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //     } else {
    //         auto dprops = at::cuda::getCurrentDeviceProperties();
    //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
    //             // TD [2022-06-05] Keep K in registers to reduce register spilling
    //             // Gives about 6% speedup compared to using block size 128.
    //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u>;
    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //         } else {  // Need to use the same block size as backward
    //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
    //         }
    //     }
    // }
}
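
// Illustrative call pattern (a sketch, not part of this file; the setup helper
// named below is hypothetical): the host-side binding is expected to drive the
// launcher in two passes, first to size per-thread work, then to launch.
//
//   Launch_params<Fused_multihead_attention_fprop_params> launch_params = make_launch_params(...);  // hypothetical setup
//   run_fmha_fp16_sm80(launch_params, /*configure=*/true);    // fills launch_params.elts_per_thread only
//   // ... consume launch_params.elts_per_thread (e.g. dropout RNG offsets) ...
//   run_fmha_fp16_sm80(launch_params, /*configure=*/false);   // actual kernel launch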