gaoqiong / flash-attention · Commits

Commit 6a89b2f1, authored Sep 03, 2023 by Tri Dao:
Remove constexpr in launch template to fix CI compilation
parent 97ba7a62

Showing 5 changed files, with 10 additions and 10 deletions:
csrc/cutlass (+1, -1)
csrc/flash_attn/src/flash_bwd_launch_template.h (+4, -4)
csrc/flash_attn/src/flash_fwd_launch_template.h (+2, -2)
flash_attn/__init__.py (+1, -1)
training/Dockerfile (+2, -2)
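
The change itself is mechanical: in every launch template below, the guard around the large-shared-memory opt-in switches from `if constexpr (...)` to a plain `if (...)`. The commit message only says this fixes CI compilation; a plausible reading (an assumption, not stated in the commit) is that `if constexpr` requires a constant-expression condition, which not every toolchain accepted in these contexts, whereas a runtime `if` compiles everywhere and is folded away whenever the value really is a compile-time constant. A minimal C++ sketch of the distinction, with placeholder names:

    #include <cstddef>

    // Fine with `if constexpr`: the condition is a constant expression.
    template <size_t SmemBytes>
    void guard_compile_time() {
        if constexpr (SmemBytes >= 48 * 1024) {
            // ... opt the kernel into extra shared memory here ...
        }
    }

    // With an ordinary runtime value, only the plain `if` form compiles;
    // the commented-out line would be rejected because smem_bytes is not
    // a constant expression.
    void guard_runtime(size_t smem_bytes) {
        // if constexpr (smem_bytes >= 48 * 1024) { }   // ill-formed
        if (smem_bytes >= 48 * 1024) {
            // ... same opt-in, decided at run time ...
        }
    }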

csrc/cutlass @ 34fd9805 (compare 3a8f57a3...34fd9805)
-Subproject commit 3a8f57a3c89cfff7aa686e95f13d9ad850f61898
+Subproject commit 34fd98056b69fbf7f0929b3f734bb5f00642e2c9

csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-                if constexpr (smem_size_dq_dk_dv >= 48 * 1024)  {
+                if (smem_size_dq_dk_dv >= 48 * 1024)  {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                 }
@@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
             });
         auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
-        if constexpr (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
+        if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
             C10_CUDA_CHECK(cudaFuncSetAttribute(
                 kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
         }
@@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
             // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
-            if constexpr (smem_size_dq_dk_dv >= 48 * 1024)  {
+            if (smem_size_dq_dk_dv >= 48 * 1024)  {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
                     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
             }
@@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
         });
     auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
-    if constexpr (Kernel_traits::kSmemKVSize >= 48 * 1024)  {
+    if (Kernel_traits::kSmemKVSize >= 48 * 1024)  {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
             kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
     }
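
All four backward hunks guard the same call: cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize. By default a CUDA kernel may be launched with at most 48 KB of dynamic shared memory; anything larger must be opted into per kernel before the launch. Below is a minimal, self-contained sketch of that pattern with a placeholder kernel and launch configuration — not the repository's launch code, whose sizes come from each Kernel_traits instantiation:

    #include <cuda_runtime.h>

    __global__ void dummy_kernel(float *out) {
        extern __shared__ float smem[];        // dynamic shared memory
        smem[threadIdx.x] = static_cast<float>(threadIdx.x);
        __syncthreads();
        out[threadIdx.x] = smem[threadIdx.x];
    }

    void launch(float *out, size_t smem_size, cudaStream_t stream) {
        auto kernel = &dummy_kernel;
        // Same runtime guard as the patched launch templates: only kernels
        // that want more than 48 KB need the explicit opt-in.
        if (smem_size >= 48 * 1024) {
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                 static_cast<int>(smem_size));
        }
        kernel<<<1, 128, smem_size, stream>>>(out);
    }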
csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -46,7 +46,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                 // Will only return softmax if dropout, to reduce compilation time.
                 auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                 // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                if constexpr (smem_size >= 48 * 1024)  {
+                if (smem_size >= 48 * 1024)  {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                 }
@@ -74,7 +74,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
             // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-            if constexpr (smem_size >= 48 * 1024)  {
+            if (smem_size >= 48 * 1024)  {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
                     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
             }
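
The forward and backward launchers both wrap the kernel selection in BOOL_SWITCH, which turns a runtime boolean such as is_even_K into a constexpr template parameter by instantiating both branches. The macro below is a sketch of that pattern, not necessarily the exact definition in csrc/flash_attn/src/static_switch.h, and the kernel and dispatch function are hypothetical:

    #include <cuda_runtime.h>

    // Sketch of a BOOL_SWITCH-style macro: the runtime condition picks one of
    // two branches, each of which sees the flag as a compile-time constant.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)             \
        [&] {                                              \
            if (COND) {                                    \
                static constexpr bool CONST_NAME = true;   \
                return __VA_ARGS__();                      \
            } else {                                       \
                static constexpr bool CONST_NAME = false;  \
                return __VA_ARGS__();                      \
            }                                              \
        }()

    template <bool IsEvenK>
    __global__ void toy_kernel() { /* specialized per flag value */ }

    // Hypothetical dispatcher: both toy_kernel<true> and toy_kernel<false>
    // are compiled; the runtime flag selects which one is launched.
    void dispatch(bool is_even_K, cudaStream_t stream) {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            auto kernel = &toy_kernel<IsEvenKConst>;
            kernel<<<1, 32, 0, stream>>>();
        });
    }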
flash_attn/__init__.py
-__version__ = "2.1.2.post1"
+__version__ = "2.1.2.post2"
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

 # Install FlashAttention
-RUN pip install flash-attn==2.1.2.post1
+RUN pip install flash-attn==2.1.2.post2

 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.1.2.post1 \
+    && cd flash-attention && git checkout v2.1.2.post2 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/xentropy && pip install . && cd ../../ \