Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhangdong1
Block-Sparse-Attention
Commits
99511c34
Commit
99511c34
authored
Mar 21, 2025
by
sxtyzhangzk
Browse files
[major] fix compile error when sm_75 is enabled in nvcc
parent
0d23f715
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
0 deletions
+30
-0
csrc/block_sparse_attn/src/flash_fwd_launch_template.h
csrc/block_sparse_attn/src/flash_fwd_launch_template.h
+30
-0
No files found.
csrc/block_sparse_attn/src/flash_fwd_launch_template.h
View file @
99511c34
...
...
@@ -32,27 +32,57 @@ using namespace pytorch_compat;
#include "flash.h"
#include "flash_fwd_kernel.h"
// Runtime fallback for kernels launched on an architecture (< sm_80) that this
// implementation does not support: emit one diagnostic message, then abort the
// kernel. Must be called by every thread of the launch (it contains a
// block-wide __syncthreads()).
__device__ __forceinline__ static void trap_unsupported_arch() {
    // Restrict the diagnostic to exactly one thread of the whole grid.
    // The previous guard only tested blockIdx.x/.y and threadIdx.x, which
    // repeats the message once per blockIdx.z slice and per threadIdx.y/.z
    // lane on 3-D launches.
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
        threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
        printf("This kernel is not supported on your GPU\n");
    }
    __syncthreads();
    // Sleep briefly before trapping — presumably to give the device-side
    // printf buffer a chance to be flushed before __trap() tears the kernel
    // down. NOTE(review): __nanosleep requires compute capability >= 7.0;
    // confirm this TU is never compiled for sm_6x or older.
    __nanosleep(1000000);
    __trap();
}
// Forward flash-attention kernel entry point. The real body is only compiled
// for sm_80 and newer; when nvcc targets an older architecture the kernel
// degrades to a runtime trap with a diagnostic instead of a compile error.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local,
         bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    // Is_causal and Is_local are mutually exclusive: when Is_local is true,
    // Is_causal must be false.
    static_assert(!(Is_causal && Is_local));
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local,
                        Has_alibi, Is_even_MN, Is_even_K,
                        Return_softmax>(params);
#else
    // Device code compiled for sm_7x or older: abort at runtime.
    trap_unsupported_arch();
#endif
}
// Block-sparse forward flash-attention kernel entry point. The real body is
// only compiled for sm_80 and newer; on older target architectures the kernel
// degrades to a runtime trap with a diagnostic instead of a compile error.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local,
         bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax,
         bool Is_exact_streaming>
__global__ void flash_fwd_block_kernel(Flash_fwd_params params) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    // Is_causal and Is_local are mutually exclusive: when Is_local is true,
    // Is_causal must be false.
    static_assert(!(Is_causal && Is_local));
    flash::compute_block_attn<Kernel_traits, Is_dropout, Is_causal, Is_local,
                              Has_alibi, Is_even_MN, Is_even_K, Return_softmax,
                              Is_exact_streaming>(params);
#else
    // Device code compiled for sm_7x or older: abort at runtime.
    trap_unsupported_arch();
#endif
}
// Split-KV forward flash-attention kernel entry point. The real body is only
// compiled for sm_80 and newer; on older target architectures the kernel
// degrades to a runtime trap with a diagnostic instead of a compile error.
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
         bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi,
                                Is_even_MN, Is_even_K, Split,
                                Append_KV>(params);
#else
    // Device code compiled for sm_7x or older: abort at runtime.
    trap_unsupported_arch();
#endif
}
// Combine step of the split-KV forward pass: merges the partial attention
// results produced by flash_fwd_splitkv_kernel, parallelized over seq-K.
// Only compiled for sm_80 and newer; on older target architectures the kernel
// degrades to a runtime trap with a diagnostic instead of a compile error.
template<typename Kernel_traits, int kBlockM, int Log_max_splits,
         bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    // Combining fewer than 2^1 splits makes no sense.
    static_assert(Log_max_splits >= 1);
    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits,
                                      Is_even_K>(params);
#else
    // Device code compiled for sm_7x or older: abort at runtime.
    trap_unsupported_arch();
#endif
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment