Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
37e32feb
Unverified
Commit
37e32feb
authored
Sep 01, 2023
by
Sophia Wisdom
Committed by
GitHub
Sep 01, 2023
Browse files
Remove commented out code in bwd (#512)
* Remove lots of comments * Remove unused traits
parent
dd8a7549
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1 addition
and
45 deletions
+1
-45
csrc/flash_attn/src/flash_bwd_launch_template.h
csrc/flash_attn/src/flash_bwd_launch_template.h
+1
-39
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+0
-6
No files found.
csrc/flash_attn/src/flash_bwd_launch_template.h
View file @
37e32feb
...
...
@@ -126,45 +126,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
template
<
typename
Kernel_traits
,
bool
Is_dropout
>
void
run_flash_bwd
(
Flash_bwd_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
if
(
configure
)
return
;
// dim3 grid(params.b, params.h);
// const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
// dim3 grid_m(num_m_block, params.b, params.h);
// if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA)
run_flash_bwd_seqk_parallel
<
Kernel_traits
,
Is_dropout
>
(
params
,
stream
,
configure
);
// } else {
// run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
// }
// // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
// // for cu_seqlens_q as well.
// const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
// BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
// // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
// if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
// }
// kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
// });
// });
// });
// });
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
// if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
// }
// kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
run_flash_bwd_seqk_parallel
<
Kernel_traits
,
Is_dropout
>
(
params
,
stream
,
configure
);
}
//
...
...
csrc/flash_attn/src/kernel_traits.h
View file @
37e32feb
...
...
@@ -316,17 +316,11 @@ struct Flash_bwd_kernel_traits : public Base {
static
constexpr
int
kSmemdSCount
=
size
(
SmemLayoutPdS
{});
static
constexpr
int
kSmemPCount
=
size
(
SmemLayoutPdS
{});
static
constexpr
int
kSmemdQCount
=
size
(
SmemLayoutdQ
{});
static
constexpr
int
kSmemdPsumCount
=
kBlockM
;
static
constexpr
int
kSmemQdOSize
=
kSmemQdOCount
*
sizeof
(
Element
);
static
constexpr
int
kSmemKVSize
=
kSmemKVCount
*
sizeof
(
Element
);
static
constexpr
int
kSmemdSSize
=
kSmemdSCount
*
sizeof
(
Element
);
static
constexpr
int
kSmemPSize
=
kSmemPCount
*
sizeof
(
Element
);
static
constexpr
int
kSmemdQSize
=
kSmemdQCount
*
sizeof
(
Element
);
static
constexpr
int
kSmemdPsumSize
=
kSmemdPsumCount
*
sizeof
(
ElementAccum
);
static
constexpr
int
kSmemSize
=
kSmemQdOSize
+
(
!
Is_V_in_regs
?
kSmemKVSize
+
kSmemdSSize
+
std
::
max
(
kSmemPSize
,
kSmemdQSize
)
:
std
::
max
(
kSmemKVSize
,
kSmemKVSize
/
2
+
kSmemdSSize
+
std
::
max
(
kSmemPSize
,
kSmemdQSize
)));
static
constexpr
int
kSmemSize1colblock
=
kSmemQdOSize
+
(
!
Is_V_in_regs
?
kSmemKVSize
+
kSmemdSSize
+
kSmemPSize
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment