gaoqiong / flash-attention · Commits

Commit 65f723bb
Authored Jul 23, 2024 by Tri Dao
Parent 5ca83a9c

Split bwd into more .cu files to speed up compilation

Changes: 89 · Showing 20 changed files with 117 additions and 69 deletions
csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu    +10   -0
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu            +3   -3
csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu    +10   -0
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu            +3   -3
csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu    +10   -0
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu            +3   -3
csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu    +10   -0
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu            +3   -3
csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu    +10   -0
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu            +3   -3
csrc/flash_attn/src/flash_bwd_launch_template.h             +43  -45
csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu    +1   -1
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu           +1   -1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu    +1   -1
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu           +1   -1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu    +1   -1
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu           +1   -1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu    +1   -1
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu           +1   -1
csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu    +1   -1
csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu  (new file, mode 100644)

+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);
+}
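Each of these generated translation units pins one (element type, head dimension, is-causal) combination into a single explicit specialization of run_mha_bwd_, so the causal and non-causal backward kernels are now compiled in separate files rather than both being instantiated from one. The sketch below is illustrative only and not part of this commit: it shows how a host-side caller could map the runtime params.is_causal flag onto the explicit instantiations provided by this file and by its non-causal counterpart just below. The wrapper name is hypothetical, and the declaration of run_mha_bwd_<T, Headdim, Is_causal> is assumed to come from the project's flash.h-style header.

// Illustrative sketch, not from this commit: runtime flag -> compile-time instantiation.
#include <cutlass/numeric_types.h>   // cutlass::half_t
#include "flash.h"                   // assumed header declaring Flash_bwd_params and
                                     //   template<typename T, int Headdim, bool Is_causal>
                                     //   void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

// Hypothetical helper; the real dispatch lives elsewhere in the repo.
void launch_bwd_hdim32_fp16(Flash_bwd_params &params, cudaStream_t stream) {
    if (params.is_causal) {
        // Resolved at link time by flash_bwd_hdim32_fp16_causal_sm80.cu.
        run_mha_bwd_<cutlass::half_t, 32, true>(params, stream);
    } else {
        // Resolved at link time by flash_bwd_hdim32_fp16_sm80.cu.
        run_mha_bwd_<cutlass::half_t, 32, false>(params, stream);
    }
}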
csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu  (new file, mode 100644)

+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu  (new file, mode 100644)

+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu  (new file, mode 100644)

+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu  (new file, mode 100644)

+// Copyright (c) 2024, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_bwd_launch_template.h"
+
+template<>
+void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);
+}
csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"

 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
 }
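All of the hdim32/64/96 backward files above follow the same shape: one translation unit per (dtype, head dimension, causal) combination, each containing a single explicit specialization of run_mha_bwd_ that forwards to the matching run_mha_bwd_hdimN launcher. The toy example below is not from the repository; it is a minimal sketch of the underlying technique, shown under assumed names: declare a function template once in a header, then put each heavy explicit specialization in its own source file so the specializations build in parallel and changing one variant does not force recompiling the others.

// kernels.h  (toy names, for illustration only)
#pragma once
#include <cstddef>

// Declared here; each specialization is defined once, in its own source file.
template<typename T, int HeadDim, bool IsCausal>
void run_bwd(T *grad, std::size_t n);

// kernels_f32_hdim64_causal.cpp -- one specialization per translation unit
#include "kernels.h"
template<>
void run_bwd<float, 64, true>(float *grad, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) grad[i] *= 0.5f;  // stand-in for the real kernel launch
}

// kernels_f32_hdim64.cpp
#include "kernels.h"
template<>
void run_bwd<float, 64, false>(float *grad, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) grad[i] *= 2.0f;  // stand-in for the real kernel launch
}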
csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -65,7 +65,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     flash::convert_dKV<Kernel_traits>(params);
 }

-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
 void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);

@@ -90,24 +90,22 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                 ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                     SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
                         auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
                         // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                         if (smem_size_dq_dk_dv >= 48 * 1024) {
                             C10_CUDA_CHECK(cudaFuncSetAttribute(
                                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                         }
                         kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
                         C10_CUDA_KERNEL_LAUNCH_CHECK();
                     });
                 });
             });
-    });

@@ -123,14 +121,14 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-template<typename Kernel_traits, bool Is_dropout>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
 #ifndef FLASHATTENTION_DISABLE_BACKWARD
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
 #endif
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;

@@ -144,17 +142,17 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
             if constexpr (!Is_dropout) {  // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
         } else {  // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;

@@ -174,13 +172,13 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
     // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
     if (max_smem_per_block >= 144 * 1024) {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         // This has a lot of register spilling
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
     } else {
         // if (params.h == params.h_k) {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
         // } else {

@@ -199,7 +197,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;

@@ -214,18 +212,18 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr (!Is_dropout) {  // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {  // 116 KB
                 // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;

@@ -243,7 +241,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
     if (max_smem_per_block >= 144 * 1024) {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
         // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
         // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);

@@ -251,7 +249,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
     } else {
         // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);
     }
     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

@@ -259,7 +257,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     int device;

@@ -272,14 +270,14 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;

@@ -292,14 +290,14 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);
         }
     });
 }

-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;

@@ -312,12 +310,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     }
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
         } else if (max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);
         } else {  // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
             if constexpr (!Is_dropout) {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
             }
         }
     });
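The substantive change in flash_bwd_launch_template.h is that Is_causal is now a template parameter threaded down from the per-file explicit instantiations (run_flash_bwd_seqk_parallel, run_flash_bwd, and the run_mha_bwd_hdimN launchers all gain it), replacing the runtime BOOL_SWITCH(params.is_causal, ...) that previously sat at the top of the switch stack. For reference, a switch macro of that style looks roughly like the sketch below; the repository's own definition lives in static_switch.h and may differ in detail. Such a macro instantiates its body twice, once per boolean value, inside a single translation unit, so removing the causal switch here and fixing Is_causal per file instead roughly halves the backward-kernel instantiations each generated .cu file has to compile, which is what makes splitting into more files pay off.

// Sketch of a BOOL_SWITCH-style macro (illustrative; see static_switch.h in the
// repo for the real definition). COND is a runtime value, but CONST_NAME is a
// constexpr bool inside each branch, so the lambda body is compiled twice --
// once with CONST_NAME == true and once with CONST_NAME == false.
#define BOOL_SWITCH(COND, CONST_NAME, ...)                \
    [&] {                                                 \
        if (COND) {                                       \
            constexpr static bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            constexpr static bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()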
For the forward-kernel files shown on this page, the only change is the copyright year (+1 -1 each):

csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu

-// Copyright (c) 2023, Tri Dao.
+// Copyright (c) 2024, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
 // This file is auto-generated. See "generate_kernels.py"