flash-attention (gaoqiong) · commit ea8a25ca
Authored Jan 21, 2024 by Tri Dao

Remove configure in bwd kernel launch

Parent: af01244d
Changes: 21 files in total; showing 20 changed files on this page, with 97 additions and 114 deletions (+97, -114).
Changed files (this page):
  csrc/flash_attn/flash_api.cpp                         +8   -24
  csrc/flash_attn/src/flash.h                           +1   -1
  csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu    +2   -2
  csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu     +2   -2
  csrc/flash_attn/src/flash_bwd_launch_template.h       +54  -55
  csrc/flash_attn/src/generate_kernels.py               +2   -2
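The change is mechanical but touches every entry point of the backward pass: the trailing `const bool configure` parameter is dropped from `run_mha_bwd`, `run_mha_bwd_`, the per-head-dim launchers, and the launch template, and the callers drop the now-dead flag. A condensed before/after view of the call shape, assembled from the `mha_bwd` / `mha_varlen_bwd` hunks below (surrounding context abbreviated):

    // Before this commit: the launcher took a configure flag, and a
    // (already commented-out) configure pre-pass preceded the real launch.
    auto launch = &run_mha_bwd;
    // launch(params, stream, /*configure=*/true);
    if (seqlen_q > 0) { launch(params, stream, /*configure=*/false); }

    // After this commit: a single launch, no flag anywhere in the path.
    auto launch = &run_mha_bwd;
    if (seqlen_q > 0) { launch(params, stream); }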
csrc/flash_attn/flash_api.cpp

 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/

 // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
...
@@ -204,7 +204,7 @@ void set_params_dgrad(Flash_bwd_params &params,
 void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        FWD_HEADDIM_SWITCH(params.d, [&] {
+        HEADDIM_SWITCH(params.d, [&] {
             if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                 run_mha_fwd_<elem_type, kHeadDim>(params, stream);
             } else {
...
@@ -695,25 +695,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }

-void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
     FP16_SWITCH(!params.is_bf16, [&] {
-        if (params.d <= 32) {
-            run_mha_bwd_<elem_type, 32>(params, stream, configure);
-        } else if (params.d <= 64) {
-            run_mha_bwd_<elem_type, 64>(params, stream, configure);
-        } else if (params.d <= 96) {
-            run_mha_bwd_<elem_type, 96>(params, stream, configure);
-        } else if (params.d <= 128) {
-            run_mha_bwd_<elem_type, 128>(params, stream, configure);
-        } else if (params.d <= 160) {
-            run_mha_bwd_<elem_type, 160>(params, stream, configure);
-        } else if (params.d <= 192) {
-            run_mha_bwd_<elem_type, 192>(params, stream, configure);
-        } else if (params.d <= 224) {
-            run_mha_bwd_<elem_type, 224>(params, stream, configure);
-        } else if (params.d <= 256) {
-            run_mha_bwd_<elem_type, 256>(params, stream, configure);
-        }
+        HEADDIM_SWITCH(params.d, [&] {
+            run_mha_bwd_<elem_type, kHeadDim>(params, stream);
+        });
     });
 }
...
@@ -898,7 +884,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
...
@@ -930,7 +915,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     }

     if (seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
...
@@ -1154,7 +1139,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

     auto launch = &run_mha_bwd;
-    // launch(params, stream, /*configure=*/true);

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
...
@@ -1186,7 +1170,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     }

     if (max_seqlen_q > 0) {
-        launch(params, stream, /*configure=*/false);
+        launch(params, stream);
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
         dk_expanded.zero_();
...
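With the `configure` flag gone, the head-dimension dispatch in `run_mha_bwd` also collapses from an if/else ladder into the same `HEADDIM_SWITCH` macro the forward path now uses. The macro's definition is not part of this diff; as a rough, hypothetical sketch of what such a switch macro typically expands to (the real one lives in `static_switch.h` and may differ in detail), shown with only two of the head sizes:

    // Hypothetical sketch only: maps a runtime head size onto a compile-time
    // constant kHeadDim, so the lambda pasted in via __VA_ARGS__ can pick a
    // templated kernel. The real macro covers head dims 32 through 256.
    #define HEADDIM_SWITCH(HEADDIM, ...)            \
      [&] {                                         \
        if (HEADDIM <= 32) {                        \
          constexpr static int kHeadDim = 32;       \
          return __VA_ARGS__();                     \
        } else if (HEADDIM <= 64) {                 \
          constexpr static int kHeadDim = 64;       \
          return __VA_ARGS__();                     \
        } /* ... more head sizes ... */             \
        else {                                      \
          constexpr static int kHeadDim = 256;      \
          return __VA_ARGS__();                     \
        }                                           \
      }()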
csrc/flash_attn/src/flash.h
...
@@ -182,4 +182,4 @@ struct Flash_bwd_params : public Flash_fwd_params {
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
 }

csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu
...
@@ -5,6 +5,6 @@
 #include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
+void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
+    run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
 }
csrc/flash_attn/src/flash_bwd_launch_template.h
...
@@ -43,7 +43,7 @@ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
 }

 template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid_m(num_m_block, params.b, params.h);
     const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
...
@@ -99,13 +99,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
 }

 template<typename Kernel_traits, bool Is_dropout>
-void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    if (configure) return;
-    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
+void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
+    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
 }

 template<typename T>
-void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
...
@@ -118,18 +117,18 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) {  // 104 KB
             if constexpr (!Is_dropout) {
                 // We can afford more registers to keep V in registers
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
             } else {
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
             }
         } else {  // 96 KB
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
         }
     });
 }

 template<typename T>
-void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
...
@@ -142,39 +141,39 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         // Changing AtomLayoutMdQ from 2 to 4 takes the same time
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
         // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
             // This has a lot of register spilling
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
         } else {
             // if (params.h == params.h_k) {
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
             // } else {
             // }
         }
     });
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
     // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
-    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+    // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
 }

 template<typename T>
-void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
...
@@ -188,19 +187,19 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
             if constexpr (!Is_dropout) {  // 92KB
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
             } else {  // 116 KB
                 // This is faster for dropout since we don't have many registers to spare
-                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+                run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
             }
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
         }
     });
 }

 template<typename T>
-void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
...
@@ -212,29 +211,29 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
         // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
         if (max_smem_per_block >= 144 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
         } else {
-            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
+            // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
         }
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
-        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
+        // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
     });
 }

 template<typename T>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
...
@@ -246,15 +245,15 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
         }
     });
 }

 template<typename T>
-void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
...
@@ -266,23 +265,23 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 136 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
         }
     });
 }

 template<typename T>
-void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
+        run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
     });
 }

 template<typename T>
-void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
+void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
...
@@ -294,9 +293,9 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bo
     }
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         if (max_smem_per_block >= 176 * 1024) {  // H100
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
         } else {  // A100, we don't do double buffering to save smem
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
+            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
         }
     });
 }
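A note on the instantiations above, for readers skimming the numeric template arguments: judging from the comments in this file (this is an inference, not something the diff states), the `Flash_bwd_kernel_traits` arguments appear to be the head dimension, the kBlockM and kBlockN tile sizes, the number of warps, three MMA atom-layout factors, a flag for keeping V in registers, a flag that disables double buffering to save shared memory (as in the A100 branch of hdim256), and the element type. A hypothetically annotated copy of that A100 instantiation:

    // Annotation is an inference from the surrounding comments; the exact
    // parameter names in kernel_traits.h may differ.
    run_flash_bwd<Flash_bwd_kernel_traits</*Headdim=*/256,
                                          /*kBlockM=*/64, /*kBlockN=*/64,
                                          /*kNWarps=*/8,
                                          /*atom layouts=*/4, 2, 2,
                                          /*V in registers=*/false,
                                          /*no double buffering=*/true,
                                          /*elem_type=*/T>,
                  Is_dropout>(params, stream);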
csrc/flash_attn/src/generate_kernels.py
...
@@ -32,8 +32,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params
 KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

 template<>
-void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
-    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
+void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
+    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
 }}
"""
...