"vscode:/vscode.git/clone" did not exist on "39b8e4430ebe84c409163b9970145b0d7d53a36c"
Commit 65f723bb authored by Tri Dao's avatar Tri Dao
Browse files

Split bwd into more .cu files to speed up compilation

parent 5ca83a9c
......@@ -800,7 +800,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
});
});
});
}
......
......@@ -192,4 +192,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
void run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment