Commit 65f723bb authored by Tri Dao
Browse files

Split bwd into more .cu files to speed up compilation

parent 5ca83a9c
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
// Copyright (c) 2023, Tri Dao.
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
......
......@@ -33,8 +33,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Fla
# C++ source template for one auto-generated backward-pass instantiation file.
# The diff residue above collapsed the old 2-parameter and new 3-parameter
# signatures into one string; this keeps only the post-commit version, which
# adds the {IS_CAUSAL} template argument so causal and non-causal backward
# kernels each compile in their own translation unit (faster parallel builds),
# matching the causal .cu files added to setup.py.
# Doubled braces ({{ }}) are literal braces under str.format().
KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}
"""
......@@ -55,7 +55,7 @@ class Kernel:
)
elif self.direction == "bwd":
return KERNEL_IMPL_TEMPLATE_BWD.format(
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
)
else:
return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
......@@ -68,16 +68,13 @@ class Kernel:
def get_all_kernels() -> List[Kernel]:
    """Yield a Kernel spec for every generated .cu file.

    The diff residue above interleaved the pre-commit version (separate
    "bwd" loop without an is_causal axis) with the post-commit version;
    this keeps only the post-commit behavior: all three directions —
    "fwd", "fwd_split", and "bwd" — take the full cross product including
    is_causal, so causal and non-causal backward kernels are emitted as
    separate files.
    """
    for direction in ["fwd", "fwd_split", "bwd"]:
        # Cross product over every dtype / head dimension / causal flag / SM arch.
        for dtype, head_dim, is_causal, sm in itertools.product(
            DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM
        ):
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
prelude = """// Copyright (c) 2023, Tri Dao.
prelude = """// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
......
......@@ -222,6 +222,20 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment