// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
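    // FP16_SWITCH dispatches on params.is_bf16 and instantiates the body for the
    // matching element type (defining elem_type; presumably __half for fp16 and
    // the bf16 type otherwise), so each Kernel_traits below is compiled for both.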
    FP16_SWITCH(params.is_bf16, ({
        auto dprops = at::cuda::getCurrentDeviceProperties();
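        // Select kernel traits based on key sequence length and GPU architecture.
        // Note that a seqlen_k strictly between 128 and 256 launches nothing here;
        // presumably callers round seqlen_k up to one of these block sizes.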
        if (params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            if (dprops->major == 8 && dprops->minor == 0) {
                // Don't share smem for K & V, and don't keep V in registers
                // This speeds things up by 2-3% by avoiding register spills, but it
                // uses more shared memory, which is fine on A100 but not other GPUs.
                // For other GPUs, we keep V in registers.
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 8 && dprops->minor > 0) {
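                // Other SM 8.x GPUs (e.g. SM 8.6) have less shared memory per SM
                // than A100, so presumably the default 0x08u flags share smem for
                // K & V instead.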
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 7 && dprops->minor == 5) {
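                // SM 7.5 (Turing) has even less shared memory, so fall back to the
                // 128-wide sequence tile even though seqlen_k >= 256.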
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            }
        }
    }));
}