gaoqiong / flash-attention

Commit 908511b2, authored Jul 10, 2024 by Tri Dao
Parent: 1d536d7d

    Split into more .cu files to speed up compilation

Changes: 69 files in the commit; this page shows 20 changed files with 235 additions and 162 deletions (+235, -162).
Files in this page of the diff:

  csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu         +10   -0
  csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu                 +2   -2
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu          +10   -0
  csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu                  +2   -2
  csrc/flash_attn/src/flash_fwd_launch_template.h                  +128 -146
  csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu    +7   -0
  csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu           +1   -1
  csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu    +7   -0
  csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu           +1   -1
  csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu    +7   -0
csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
}
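The "template<>" definition above is a full explicit specialization of run_mha_fwd_; the primary template it specializes is declared elsewhere in the tree and is not shown on this page. Judging from the arguments used here, that declaration presumably looks roughly like the following sketch (an assumption, not text from this diff):

// Assumed primary-template declaration (e.g. in a shared header such as flash.h; not part of this diff):
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);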
csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
 }
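With this commit each (data type, head dimension, causal) combination gets its own translation unit: the causal variant lives in the new *_causal_sm80.cu file above, and the non-causal variant keeps the original file, now instantiated with an explicit false. A hedged sketch of how a caller can route a runtime causal flag to the two hdim256 fp16 specializations (the wrapper name and the included header are assumptions, not part of this diff):

#include "flash.h"  // assumed to declare run_mha_fwd_<T, Headdim, Is_causal> and Flash_fwd_params

// Hypothetical helper: picks the specialization compiled in flash_fwd_hdim256_fp16_causal_sm80.cu
// or in flash_fwd_hdim256_fp16_sm80.cu based on the runtime flag carried in the params struct.
void forward_hdim256_fp16(Flash_fwd_params &params, cudaStream_t stream) {
    if (params.is_causal) {
        run_mha_fwd_<cutlass::half_t, 256, true>(params, stream);
    } else {
        run_mha_fwd_<cutlass::half_t, 256, false>(params, stream);
    }
}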
csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
}
csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu

@@ -5,6 +5,6 @@
 #include "flash_fwd_launch_template.h"
 
 template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
 }
csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -95,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename Kernel_traits>
+template<typename Kernel_traits, bool Is_causal>
 void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
     static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
@@ -104,7 +104,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
     const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
-    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
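The grid shape above packs either (M-blocks, batch, heads) or, when the KV sequence is split, (M-blocks, splits, batch x heads). A small worked example with made-up sizes, just to make the two ternaries concrete (a sketch, not code from the diff):

#include <cuda_runtime.h>
#include <cassert>

int main() {
    const int num_m_block = 8, b = 2, h = 16;    // hypothetical sizes

    int num_splits = 1;                           // split-KV disabled
    dim3 grid_plain(num_m_block, num_splits > 1 ? num_splits : b, num_splits > 1 ? b * h : h);
    assert(grid_plain.y == 2 && grid_plain.z == 16);   // y = batch, z = heads

    num_splits = 4;                               // KV split into 4 chunks
    dim3 grid_split(num_m_block, num_splits > 1 ? num_splits : b, num_splits > 1 ? b * h : h);
    assert(grid_split.y == 4 && grid_split.z == 32);   // y = splits, z = batch * heads
    return 0;
}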
@@ -131,7 +130,6 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             });
         });
     });
-    });
     if (params.num_splits > 1) {
         // We want kBlockM to be as small as possible for more parallelism.
         // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
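The hunk above drops the BOOL_SWITCH over params.is_causal because Is_causal is now a compile-time template parameter of run_flash_splitkv_fwd. For readers without the repo at hand: the *_SWITCH macros used throughout this header are defined in a separate file that is not part of this diff, and they typically turn a runtime bool into a constexpr visible inside a lambda, roughly along these lines (a sketch, not the repo's exact code):

#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)      \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()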
@@ -159,31 +157,28 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     }
 }
 
-template<typename T, int Headdim>
+template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 32;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 64;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                 // Using block size (64 x 256) is 27% slower for seqlen=2k
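The kBlockN expression in the hunk above fixes the key/value tile width from the head dimension alone. A quick compile-time check of what it evaluates to (derived directly from that expression; the helper name is made up for the check):

constexpr int kBlockN_for(int Headdim) {
    return Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
}
static_assert(kBlockN_for(32) == 256 && kBlockN_for(64) == 256, "hdim <= 64 -> kBlockN = 256");
static_assert(kBlockN_for(96) == 128 && kBlockN_for(128) == 128, "64 < hdim <= 128 -> kBlockN = 128");
static_assert(kBlockN_for(160) == 64 && kBlockN_for(256) == 64, "hdim > 128 -> kBlockN = 64");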
@@ -198,16 +193,14 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
                 // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
             }
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             if (is_sm8x) {
                 if constexpr(!Is_causal) {
@@ -224,16 +217,14 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
             // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                 // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
@@ -261,16 +252,14 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
                 // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
             }
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, H100, 128 x 32 is the fastest.
             // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
             // and 128 x 64 with 8 warps is the fastest for non-causal.
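The is_sm8x flag used in these launchers is true for compute capability 8.6 and 8.9 (consumer Ampere and Ada) but false for 8.0 (A100); those parts have less shared memory per SM, which is why the comments prefer smaller, squarer tiles for causal attention on them. The header reads the properties through at::cuda::getCurrentDeviceProperties(); a PyTorch-free sketch of the same check with the raw CUDA runtime API:

#include <cuda_runtime.h>

static bool current_device_is_sm8x() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    // True for sm86 / sm89 (major 8, minor > 0); false for sm80 (A100) and for sm90 (Hopper).
    return prop.major == 8 && prop.minor > 0;
}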
@@ -291,14 +280,12 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
@@ -310,10 +297,9 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 224;
     int device;
@@ -326,7 +312,6 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_block = %d\n", max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
                 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
             } else {
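The smem threshold in the hunk above, 2 * Headdim * (128 + 2 * 64), reads as 2-byte (fp16/bf16) elements times a 128-row Q tile plus 64-row K and V tiles, each Headdim wide. A quick arithmetic check that it matches the "112 KB" comment for Headdim = 224 (a sketch; the constant name is made up):

constexpr int kSmemBytesHdim224 = 2 * 224 * (128 + 2 * 64);
static_assert(kSmemBytesHdim224 == 114688, "2 * 224 * 256 bytes");
static_assert(kSmemBytesHdim224 == 112 * 1024, "= 112 KB, matching the comment in the diff");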
@@ -339,10 +324,9 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
             // is 8 elements. This means we can only use 128 threads and not 256 threads.
             // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
     });
 }
 
-template<typename T>
+template<typename T, bool Is_causal>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     int device;
@@ -357,7 +341,6 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     }
     // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
     DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             // For A100, we want to run with 128 x 64 (128KB smem).
             // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
             if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
@@ -370,5 +353,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
                 // 96 KB
                 // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        });
     });
 }
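The hdim256 branch above compares against two quantities: 2 * Headdim * (128 + 2 * 64) is the smem for one 128 x 64 tile configuration, and 4 * Headdim * (64 + 2 * 64) is the smem two 64 x 64 CTAs would need on one SM. Evaluating both for Headdim = 256 (the per-device shared-memory figures in the trailing comment are approximate public A100/H100 specs, not numbers from this diff):

constexpr int kSmem128x64 = 2 * 256 * (128 + 2 * 64);    // one CTA with a 128-row Q tile and 64-row K/V tiles
constexpr int kSmemTwo64x64 = 4 * 256 * (64 + 2 * 64);   // two CTAs, each with 64-row Q and 64-row K/V tiles
static_assert(kSmem128x64 == 128 * 1024, "128 KB, the A100 choice named in the comment");
static_assert(kSmemTwo64x64 == 192 * 1024, "192 KB per SM needed to co-schedule two 64 x 64 CTAs");
// With roughly 164 KB of smem per SM an A100 cannot fit two 64 x 64 CTAs (192 KB), so it takes the
// 128 x 64 branch; an H100 (~228 KB per SM) can, so it falls through to the 64 x 64 / 96 KB path.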
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu

@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
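Note the difference in form between the two groups of files in this commit: the flash_fwd_hdim*_sm80.cu files use explicit specialization (template<> ... { ... }), where the .cu file itself supplies the body, while these split-KV files use explicit instantiation (template void f<...>(...);), which forces the compiler to emit code for the template already defined in flash_fwd_launch_template.h. A declaration-level sketch of the two forms, using signatures that appear in this commit:

// explicit specialization (body provided in the .cu file), as in flash_fwd_hdim256_fp16_sm80.cu:
template<> void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);
// explicit instantiation of an already-defined template, as in flash_fwd_split_hdim128_fp16_sm80.cu:
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);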
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu

@@ -4,4 +4,4 @@
 #include "flash_fwd_launch_template.h"
 
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu  (new file, mode 100644)

// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);