Commit 908511b2 authored by Tri Dao

Split into more .cu files to speed up compilation

parent 1d536d7d
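The mechanism behind the speedup: the forward launcher is a function template, declared once in a shared header and fully specialized per (dtype, head dim, causal) combination, with each specialization placed in its own .cu file so nvcc compiles them in parallel and each translation unit instantiates fewer kernels. Below is a minimal, self-contained sketch of that pattern with hypothetical names (run_fwd stands in for run_mha_fwd_, float for the cutlass element types); it is written as a single file, with comments marking where the real code splits into files:

#include <cstdio>

// --- shared header (flash.h plays this role in the commit) ---
// Declaration only: callers compile against this, and the linker resolves
// each call to whichever .cu file defines the matching specialization.
template<typename T, int Headdim, bool Is_causal>
void run_fwd();

// --- one auto-generated .cu file per combination ---
template<>
void run_fwd<float, 128, true>() { std::printf("hdim 128, causal\n"); }

template<>
void run_fwd<float, 128, false>() { std::printf("hdim 128, non-causal\n"); }

// --- caller ---
int main() {
    run_fwd<float, 128, true>();
    run_fwd<float, 128, false>();
}

Splitting on the causal flag as well as on head dimension roughly halves the instantiation work per translation unit compared to compiling both branches in one file.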
@@ -238,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params &params,
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
         HEADDIM_SWITCH(params.d, [&] {
-            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
-                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-            } else {
-                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-            }
+            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
+                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+                } else {
+                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
+                }
+            });
         });
     });
 }
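run_mha_fwd now lifts the runtime flag params.is_causal into the compile-time template parameter Is_causal via BOOL_SWITCH. The real macro is defined in the repo's static_switch.h; the following is a minimal sketch consistent with the usage above, not the verbatim definition:

#include <cstdio>

// Sketch of the BOOL_SWITCH pattern: branch on the runtime bool, and inside
// each branch bind a constexpr of the same value so the lambda body can use
// it as a template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

template<bool Is_causal>
void run_demo() { std::printf("Is_causal = %d\n", int(Is_causal)); }

int main() {
    bool is_causal = true;  // runtime value, standing in for params.is_causal
    BOOL_SWITCH(is_causal, Is_causal, [&] { run_demo<Is_causal>(); });
}

FP16_SWITCH and HEADDIM_SWITCH follow the same pattern for the element type and head dimension, so the innermost lambda body sees all three as compile-time constants.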
@@ -188,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
 }
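For orientation: run_mha_fwd_hdim128 comes from flash_fwd_launch_template.h, which every generated file includes, and it now also carries the Is_causal flag down to the kernel launch. A stubbed sketch of that relationship follows; the types and body are placeholders, not the real launcher, which additionally selects kernel traits and block sizes:

#include <cstdio>

// Placeholder stand-ins for the real declarations in flash.h / CUDA.
struct Flash_fwd_params { bool is_causal; };
using cudaStream_t = void*;

// Shape of the launcher in flash_fwd_launch_template.h (stub body): it fixes
// the head dimension and launches the templated kernel on `stream`.
template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    (void)params; (void)stream;  // unused in this stub
    constexpr int Headdim = 128;
    std::printf("launch: hdim=%d causal=%d\n", Headdim, int(Is_causal));
}

int main() {
    Flash_fwd_params params{true};  // float stands in for a cutlass dtype
    run_mha_fwd_hdim128<float, true>(params, nullptr);
}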
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
 }
new file
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
 }