Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
bc2c2102
Commit
bc2c2102
authored
Jul 11, 2022
by
Tri Dao
Browse files
Don't nest BOOL_SWITCH to work around gcc 7 bug
parent
d1fc80a3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
32 deletions
+38
-32
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+20
-18
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+18
-14
No files found.
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
View file @
bc2c2102
...
@@ -27,25 +27,27 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
...
@@ -27,25 +27,27 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
bool
is_dropout
=
params
.
p_dropout
<
1.
f
;
// params.p_dropout is the probability of "keeping"
bool
is_dropout
=
params
.
p_dropout
<
1.
f
;
// params.p_dropout is the probability of "keeping"
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
IsCausalConst
,
[
&
]
{
auto
kernel
=
params
.
is_causal
auto
kernel
=
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
>
Kernel_traits
,
IsDropoutConst
,
IsCausalConst
>
;
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
>
;
if
(
params
.
seqlen_k
==
blocksize_c
)
{
if
(
params
.
seqlen_k
==
blocksize_c
)
{
kernel
=
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
kernel
=
params
.
is_causal
Kernel_traits
,
IsDropoutConst
,
IsCausalConst
,
/*loop_steps=*/
1
>
;
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
/*loop_steps=*/
1
>
}
else
if
(
params
.
seqlen_k
==
blocksize_c
*
2
)
{
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
,
/*loop_steps=*/
1
>
;
kernel
=
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
}
else
if
(
params
.
seqlen_k
==
blocksize_c
*
2
)
{
Kernel_traits
,
IsDropoutConst
,
IsCausalConst
,
/*loop_steps=*/
2
>
;
kernel
=
params
.
is_causal
}
?
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
/*loop_steps=*/
2
>
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
:
&
fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
,
/*loop_steps=*/
2
>
;
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
}
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size_dq_dk_dv
));
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
}
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
dim3
grid
(
params
.
b
,
params
.
h
);
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size_dq_dk_dv
));
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size_dq_dk_dv
,
stream
>>>
(
params
);
}
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
dim3
grid
(
params
.
b
,
params
.
h
);
});
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size_dq_dk_dv
,
stream
>>>
(
params
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
});
});
}
}
...
...
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
View file @
bc2c2102
...
@@ -59,21 +59,25 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
...
@@ -59,21 +59,25 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
const
int
smem_size
=
fmha
::
get_dynamic_smem_size
<
Kernel_traits
>
()
const
int
smem_size
=
fmha
::
get_dynamic_smem_size
<
Kernel_traits
>
()
+
(
loop_steps
>
1
?
smem_size_softmax_lse
:
0
);
+
(
loop_steps
>
1
?
smem_size_softmax_lse
:
0
);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
launch_params
.
params
.
is_causal
,
IsCausalConst
,
[
&
]
{
auto
kernel
=
launch_params
.
params
.
is_causal
BOOL_SWITCH
(
launch_params
.
return_softmax
,
ReturnSoftmaxConst
,
[
&
]
{
?
(
launch_params
.
return_softmax
auto
kernel
=
&
fmha_fprop_fp16_sm80_loop_kernel
<
?
&
fmha_fprop_fp16_sm80_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
true
>
Kernel_traits
,
IsDropoutConst
,
IsCausalConst
,
ReturnSoftmaxConst
>
;
:
&
fmha_fprop_fp16_sm80_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
false
>
)
if
(
smem_size
>=
48
*
1024
)
{
:
(
launch_params
.
return_softmax
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
?
&
fmha_fprop_fp16_sm80_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
,
true
>
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
:
&
fmha_fprop_fp16_sm80_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
,
false
>
);
}
if
(
smem_size
>=
48
*
1024
)
{
dim3
grid
(
launch_params
.
params
.
b
,
launch_params
.
params
.
h
);
FMHA_CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
));
launch_params
.
params
);
}
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
dim3
grid
(
launch_params
.
params
.
b
,
launch_params
.
params
.
h
);
});
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
});
launch_params
.
params
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment