gaoqiong / flash-attention · commit 43617dea

Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation

Authored Sep 18, 2023 by Tri Dao
Parent: fa3ddcba
Showing 3 changed files with 9 additions and 5 deletions:

csrc/flash_attn/src/flash_bwd_launch_template.h   +4 −2
csrc/flash_attn/src/flash_fwd_launch_template.h   +4 −2
training/Dockerfile                               +1 −1
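The C++ change is the same one-line edit in all four launchers below: the template argument that used to be just IsEvenMNConst (or IsEvenNConst) is now ANDed with IsEvenKConst, so the (IsEvenMN=true, IsEvenK=false) case no longer gets its own kernel instantiation and instead reuses the bounds-checked IsEvenMN=false one, which also handles even sequence lengths, just without the specialized fast path. The mechanism is the BOOL_SWITCH macro, which turns a runtime flag into a constexpr inside a lambda, so every switched flag doubles the number of templates the compiler must build. Below is a minimal host-only sketch of the idiom, with BOOL_SWITCH re-stated from csrc/flash_attn/src/static_switch.h and toy_kernel/launch as hypothetical stand-ins for the real CUDA kernels and launchers:

#include <cstdio>

// Re-statement of the BOOL_SWITCH idiom (the real macro lives in
// csrc/flash_attn/src/static_switch.h): the runtime condition picks a branch,
// and inside each branch the flag is a constexpr, so it can be used as a
// template argument.
#define BOOL_SWITCH(COND, CONST_NAME, ...)             \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()

// Hypothetical stand-in for a flash-attention kernel; each distinct
// (IsEvenMN, IsEvenK) pair is a separate instantiation the compiler must build.
template <bool IsEvenMN, bool IsEvenK>
void toy_kernel() {
    std::printf("toy_kernel<IsEvenMN=%d, IsEvenK=%d>\n", IsEvenMN, IsEvenK);
}

// Hypothetical launcher mirroring the dispatch in the launch templates.
void launch(bool is_even_MN, bool is_even_K) {
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            // Before this commit: &toy_kernel<IsEvenMNConst, IsEvenKConst>
            // reaches all 4 combinations. After: the (true, false) case is
            // folded into (false, false), so only 3 instantiations remain.
            auto kernel = &toy_kernel<IsEvenMNConst && IsEvenKConst, IsEvenKConst>;
            kernel();
        });
    });
}

int main() {
    launch(/*is_even_MN=*/true,  /*is_even_K=*/false);  // runs toy_kernel<false, false>
    launch(/*is_even_MN=*/true,  /*is_even_K=*/true);   // runs toy_kernel<true, true>
    launch(/*is_even_MN=*/false, /*is_even_K=*/true);   // runs toy_kernel<false, true>
}

Calling launch(true, false) now runs toy_kernel<false, false>; the compiler never emits a <true, false> body, which is exactly the instantiation this commit removes.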
csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -62,7 +62,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
+                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst, IsEvenKConst>;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -101,7 +102,8 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
...
@@ -101,7 +102,8 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
BOOL_SWITCH
(
params
.
is_causal
,
IsCausalConst
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
IsCausalConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_N
,
IsEvenNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_N
,
IsEvenNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_K
,
IsEvenKConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_K
,
IsEvenKConst
,
[
&
]
{
auto
kernel
=
&
flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel
<
Kernel_traits
,
Is_dropout
,
IsCausalConst
,
IsEvenNConst
,
IsEvenKConst
>
;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto
kernel
=
&
flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel
<
Kernel_traits
,
Is_dropout
,
IsCausalConst
,
IsEvenNConst
&&
IsEvenKConst
,
IsEvenKConst
>
;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
if
(
smem_size_dq_dk_dv
>=
48
*
1024
)
{
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
C10_CUDA_CHECK
(
cudaFuncSetAttribute
(
...
...
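Both backward launchers apply the identical folding; only the runtime case (even MN/N = true, even K = false) lands on a different instantiation than before, namely the one that already carries bounds checks. A small compile-time check of that mapping, with select_old/select_new as illustrative names that are not part of the source; each pair is (even-MN template argument, even-K template argument):

#include <cstdio>
#include <utility>

// Old and new mapping from the runtime flags to the (IsEvenMN, IsEvenK)
// template arguments chosen by the dispatch above. Illustrative only.
constexpr std::pair<bool, bool> select_old(bool even_MN, bool even_K) {
    return {even_MN, even_K};            // one instantiation per combination
}
constexpr std::pair<bool, bool> select_new(bool even_MN, bool even_K) {
    return {even_MN && even_K, even_K};  // fold (true, false) into (false, false)
}

// Three of the four runtime cases are unchanged...
static_assert(select_new(true, true)   == select_old(true, true), "");
static_assert(select_new(false, true)  == select_old(false, true), "");
static_assert(select_new(false, false) == select_old(false, false), "");
// ...and the fourth now reuses the bounds-checked IsEvenMN=false kernel.
static_assert(select_new(true, false)  == std::make_pair(false, false), "");

int main() {
    std::printf("even-MN/even-K instantiations per setting of the other switches: 4 -> 3\n");
}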
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -44,7 +44,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                 // Will only return softmax if dropout, to reduce compilation time.
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
                 if (smem_size >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -76,7 +77,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
...
@@ -76,7 +77,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
BOOL_SWITCH
(
params
.
knew_ptr
!=
nullptr
,
Append_KV
,
[
&
]
{
BOOL_SWITCH
(
params
.
knew_ptr
!=
nullptr
,
Append_KV
,
[
&
]
{
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
auto
kernel
=
&
flash_fwd_splitkv_kernel
<
Kernel_traits
,
Is_causal
,
IsEvenMNConst
&&
!
Append_KV
,
IsEvenKConst
,
Split
,
Append_KV
>
;
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto
kernel
=
&
flash_fwd_splitkv_kernel
<
Kernel_traits
,
Is_causal
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
,
IsEvenKConst
,
Split
,
Append_KV
>
;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if
(
smem_size
>=
48
*
1024
)
{
if
(
smem_size
>=
48
*
1024
)
{
...
...
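For the split-KV launcher the even-MN argument was already gated on !Append_KV, so the additional && IsEvenKConst prunes fewer combinations, but the effect is the same kind of reduction. A rough compile-time count of the distinct flash_fwd_splitkv_kernel argument tuples reachable per (Kernel_traits, Is_causal) pair, assuming the only remaining switches are the even-MN, even-K, Split and Append_KV flags shown or implied in this hunk (count_instantiations is an illustrative name, not part of the source):

#include <cstdio>

// Count the distinct (IsEvenMN, IsEvenK, Split, Append_KV) template-argument
// tuples the dispatch can produce, with and without the extra IsEvenKConst
// gate on the even-MN argument. Illustrative only; the real total also scales
// with the head-dim and dtype switches outside this hunk.
constexpr int count_instantiations(bool gate_on_even_K) {
    bool seen[2][2][2][2] = {};  // [even_MN_arg][even_K][split][append_kv]
    int count = 0;
    for (int mn = 0; mn < 2; ++mn) {
        for (int k = 0; k < 2; ++k) {
            for (int split = 0; split < 2; ++split) {
                for (int append = 0; append < 2; ++append) {
                    // Mirrors IsEvenMNConst && !Append_KV (&& IsEvenKConst).
                    const bool even_MN_arg = mn && !append && (!gate_on_even_K || k);
                    if (!seen[even_MN_arg][k][split][append]) {
                        seen[even_MN_arg][k][split][append] = true;
                        ++count;
                    }
                }
            }
        }
    }
    return count;
}

static_assert(count_instantiations(/*gate_on_even_K=*/false) == 12, "");
static_assert(count_instantiations(/*gate_on_even_K=*/true)  == 10, "");

int main() {
    std::printf("split-KV instantiations per (traits, causal) pair: %d -> %d\n",
                count_instantiations(false), count_instantiations(true));
}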
training/Dockerfile
@@ -87,7 +87,7 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
 RUN pip install flash-attn==2.2.3.post2
 
-# Install CUDA extensions for cross-entropy, fused dense, layer norm
+# Install CUDA extensions for fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
     && cd flash-attention && git checkout v2.2.3.post2 \
     && cd csrc/layer_norm && pip install . && cd ../../ \