Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
65f723bb
Commit
65f723bb
authored
Jul 23, 2024
by
Tri Dao
Browse files
Split bwd into more .cu files to speed up compilation
parent
5ca83a9c
Changes
89
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
26 additions
and
15 deletions
+26
-15
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
...flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
...flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
...flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
+1
-1
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
+1
-1
csrc/flash_attn/src/generate_kernels.py
csrc/flash_attn/src/generate_kernels.py
+5
-8
setup.py
setup.py
+14
-0
No files found.
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
View file @
65f723bb
-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
...
...
csrc/flash_attn/src/generate_kernels.py
View file @
65f723bb
...
...
@@ -33,8 +33,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);

 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
 }}
 """
...
...
@@ -55,7 +55,7 @@ class Kernel:
             )
         elif self.direction == "bwd":
-            return KERNEL_IMPL_TEMPLATE_BWD.format(DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim)
+            return KERNEL_IMPL_TEMPLATE_BWD.format(DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal)
         else:
             return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
...
...
@@ -68,16 +68,13 @@ class Kernel:
 def get_all_kernels() -> List[Kernel]:
-    for direction in ["fwd", "fwd_split"]:
+    for direction in ["fwd", "fwd_split", "bwd"]:
         for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
-    for direction in ["bwd"]:
-        for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
-            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction)
 def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
-    prelude = """// Copyright (c) 2023, Tri Dao.
+    prelude = """// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"
 \n"""
...
...
setup.py
View file @
65f723bb
...
...
@@ -222,6 +222,20 @@ if not SKIP_CUDA_BUILD and not IS_ROCM:
     "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
     "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
     "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
+    "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
     "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
     "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
     "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment