Commit 908511b2 authored by Tri Dao

Split into more .cu files to speed up compilation

parent 1d536d7d
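For a sense of scale before the diff: with two dtypes, eight head dimensions, and the new causal/non-causal split, the forward and split-KV forward paths each go from 16 to 32 generated translation units, so the build has twice as many (smaller) files that can compile in parallel. A quick back-of-the-envelope sketch (illustrative only, not part of the commit):

# Counts assume the dtype, head-dimension, and causal lists used by generate_kernels.py below.
dtypes, head_dims, causal_variants = 2, 8, 2

per_direction_before = dtypes * head_dims                   # 16 .cu files per forward direction
per_direction_after = dtypes * head_dims * causal_variants  # 32 .cu files per forward direction

# Two affected directions ("fwd" and "fwd_split"); the backward path is unchanged.
print(2 * per_direction_before, "->", 2 * per_direction_after)  # 32 -> 64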
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu (new file)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu (new file)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu (new file)
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+#include "flash_fwd_launch_template.h"
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/generate_kernels.py
@@ -16,17 +16,18 @@ DTYPE_MAP = {
 SM = [80] # Sm80 kernels support up to
 HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
+IS_CAUSAL = ["false", "true"]
 KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"
 template<>
-void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
-    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"
-template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
 """
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"
@@ -43,13 +44,14 @@ class Kernel:
     sm: int
     dtype: str
     head_dim: int
+    is_causal: bool
     direction: str
     @property
     def template(self) -> str:
         if self.direction == "fwd":
             return KERNEL_IMPL_TEMPLATE_FWD.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
@@ -57,18 +59,21 @@ class Kernel:
             )
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
-                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
     @property
     def filename(self) -> str:
-        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
+        return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"
 def get_all_kernels() -> List[Kernel]:
-    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-        for direction in ["fwd", "bwd", "fwd_split"]:
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
+    for direction in ["fwd", "fwd_split"]:
+        for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
+    for direction in ["bwd"]:
+        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
+            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
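As a standalone illustration of the new enumeration (a sketch that mirrors the loops and filename scheme above; it is not code from the commit), the forward paths now get one .cu file per (direction, dtype, head_dim, is_causal) combination, while the backward path keeps one file per (dtype, head_dim):

import itertools

# Values mirror the lists in generate_kernels.py above.
DTYPE_MAP = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}
SM = [80]
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
IS_CAUSAL = ["false", "true"]

def generated_filenames():
    # fwd and fwd_split kernels are specialized on is_causal, so each
    # (dtype, head_dim, is_causal) triple becomes its own translation unit.
    for direction in ["fwd", "fwd_split"]:
        for dtype, head_dim, is_causal, sm in itertools.product(
            DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM
        ):
            causal_tag = "causal_" if is_causal == "true" else ""
            yield f"flash_{direction}_hdim{head_dim}_{dtype}_{causal_tag}sm{sm}.cu"
    # bwd kernels are not split by causality: one file per (dtype, head_dim).
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        yield f"flash_bwd_hdim{head_dim}_{dtype}_sm{sm}.cu"

names = list(generated_filenames())
print(len(names))          # 80 = 32 fwd + 32 fwd_split + 16 bwd
print(names[0], names[1])  # flash_fwd_hdim32_fp16_sm80.cu flash_fwd_hdim32_fp16_causal_sm80.cu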
setup.py
@@ -151,6 +151,22 @@ if not SKIP_CUDA_BUILD:
 "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
@@ -183,6 +199,22 @@ if not SKIP_CUDA_BUILD:
 "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+"csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
 ],
 extra_compile_args={
     "cxx": ["-O3", "-std=c++17"] + generator_flag,