fmha_bwd_hdim128.cu 497 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

// Launcher for the FMHA backward (dgrad) path specialised in this file ("hdim128" per the
// function name). It picks the element type through FP16_SWITCH, fixes the kernel-traits
// instantiation, and hands everything to run_fmha_bwd_loop.
//
// params:    dgrad parameters. Only params.is_bf16 is read here (to drive FP16_SWITCH);
//            the object is otherwise passed through unchanged.
// stream:    CUDA stream, forwarded as-is to run_fmha_bwd_loop.
// configure: forwarded as-is to run_fmha_bwd_loop; its meaning is defined there, not here.
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    // NOTE(review): the nature of the MSVC issue is not recorded here; keep the exact form
    // of this FP16_SWITCH invocation unless the workaround has been verified as unnecessary.
    // elem_type is not declared in this function; presumably FP16_SWITCH defines it based on
    // params.is_bf16 — confirm against the macro definition.
    FP16_SWITCH(params.is_bf16, [&] {
        // The numeric template arguments are fixed for this specialisation; see
        // FMHA_kernel_traits for what each position means.
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
    });
}