Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
4a73e903
Unverified
Commit
4a73e903
authored
Mar 15, 2024
by
Driss Guessous
Committed by
GitHub
Mar 15, 2024
Browse files
Add in macros for defining __grid_constant__ (#852)
parent
2a15840f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
21 deletions
+65
-21
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api.cpp
+1
-1
csrc/flash_attn/src/flash_bwd_launch_template.h
csrc/flash_attn/src/flash_bwd_launch_template.h
+34
-11
csrc/flash_attn/src/flash_fwd_launch_template.h
csrc/flash_attn/src/flash_fwd_launch_template.h
+30
-9
No files found.
csrc/flash_attn/flash_api.cpp
View file @
4a73e903
...
...
@@ -46,7 +46,7 @@ void set_params_fprop(Flash_fwd_params &params,
bool
seqlenq_ngroups_swapped
=
false
)
{
// Reset the parameters
memset
(
&
params
,
0
,
sizeof
(
params
))
;
params
=
{}
;
params
.
is_bf16
=
q
.
dtype
()
==
torch
::
kBFloat16
;
...
...
csrc/flash_attn/src/flash_bwd_launch_template.h
View file @
4a73e903
...
...
@@ -11,6 +11,40 @@
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
// Architecture gate for the backward kernels.
// When __CUDA_ARCH__ is defined and >= 800:
//   - ARCH_SUPPORTS_FLASH is defined, and
//   - KERNEL_PARAM_MODIFIER expands to __grid_constant__.
// In every other case ARCH_SUPPORTS_FLASH stays undefined and
// KERNEL_PARAM_MODIFIER expands to nothing, so the kernel parameter is declared
// without the modifier.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Single place for the "unsupported architecture" message printed by kernels.
// NOTE: the expansion already ends with ';', so use sites write the macro name
// on its own without a trailing semicolon. The message has no trailing newline.
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Expands to the template header and __global__ signature of a backward kernel
// named `kernelName`. `Kernel_traits` is always the first template parameter;
// the variadic arguments supply the remaining template parameters. The single
// kernel argument is `const Flash_bwd_params params`, prefixed with
// KERNEL_PARAM_MODIFIER. The kernel body (in braces) follows the macro
// invocation.
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
// Backward kernel entry point: hands `params` to flash::compute_dq_dk_dv with
// the kernel's own template arguments. When ARCH_SUPPORTS_FLASH is not defined
// the body only prints the unsupported-architecture message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel,
                             bool Is_dropout, bool Is_causal, bool Has_alibi,
                             bool Is_even_M, bool Is_even_K) {
#if !defined(ARCH_SUPPORTS_FLASH)
    // The macro carries its own terminating ';'.
    FLASH_UNSUPPORTED_ARCH
#else
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi,
                            Is_even_M, Is_even_K>(params);
#endif
}
// Backward kernel entry point: hands `params` to
// flash::compute_dq_dk_dv_seqk_parallel with the kernel's own template
// arguments. When ARCH_SUPPORTS_FLASH is not defined the body only prints the
// unsupported-architecture message (and the static_assert is not compiled).
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel,
                             bool Is_dropout, bool Is_causal, bool Is_local,
                             bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
#if !defined(ARCH_SUPPORTS_FLASH)
    // The macro carries its own terminating ';'.
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local must never both be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal,
                                          Is_local, Has_alibi, Is_even_MN,
                                          Is_even_K>(params);
#endif
}
template
<
bool
Clear_dQaccum
=
true
,
typename
Kernel_traits
>
__global__
void
flash_bwd_dot_do_o_kernel
(
const
Flash_bwd_params
params
)
{
flash
::
compute_dot_do_o
<
Clear_dQaccum
,
Kernel_traits
>
(
params
);
...
...
@@ -21,17 +55,6 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
flash
::
clear_dKVaccum
<
Kernel_traits
>
(
params
);
}
// Backward kernel entry point written with an explicit template header: the
// parameter is marked __grid_constant__ unconditionally and the body forwards
// `params` to flash::compute_dq_dk_dv with the same template arguments.
template <typename Kernel_traits,
          bool Is_dropout,
          bool Is_causal,
          bool Has_alibi,
          bool Is_even_M,
          bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi,
                            Is_even_M, Is_even_K>(params);
}
// Backward kernel entry point written with an explicit template header: the
// parameter is marked __grid_constant__ unconditionally and the body forwards
// `params` to flash::compute_dq_dk_dv_seqk_parallel with the same template
// arguments.
template <typename Kernel_traits,
          bool Is_dropout,
          bool Is_causal,
          bool Is_local,
          bool Has_alibi,
          bool Is_even_MN,
          bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
    // Is_causal and Is_local must never both be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal,
                                          Is_local, Has_alibi, Is_even_MN,
                                          Is_even_K>(params);
}
template
<
typename
Kernel_traits
>
__global__
void
flash_bwd_convert_dq_kernel
(
const
Flash_bwd_params
params
,
const
int
nsplits
)
{
flash
::
convert_dQ
<
Kernel_traits
>
(
params
,
nsplits
);
...
...
csrc/flash_attn/src/flash_fwd_launch_template.h
View file @
4a73e903
...
...
@@ -10,19 +10,40 @@
#include "flash.h"
#include "flash_fwd_kernel.h"
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Has_alibi
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
>
__global__
void
flash_fwd_kernel
(
__grid_constant__
const
Flash_fwd_params
params
)
{
static_assert
(
!
(
Is_causal
&&
Is_local
));
// If Is_local is true, Is_causal should be false
// Architecture gate for the forward kernels.
// When __CUDA_ARCH__ is defined and >= 800:
//   - ARCH_SUPPORTS_FLASH is defined, and
//   - KERNEL_PARAM_MODIFIER expands to __grid_constant__.
// In every other case ARCH_SUPPORTS_FLASH stays undefined and
// KERNEL_PARAM_MODIFIER expands to nothing, so the kernel parameter is declared
// without the modifier.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Single place for the "unsupported architecture" message printed by kernels.
// NOTE: the expansion already ends with ';', so use sites write the macro name
// on its own without a trailing semicolon. The message has no trailing newline.
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Expands to the template header and __global__ signature of a forward kernel
// named `kernelName`. `Kernel_traits` is always the first template parameter;
// the variadic arguments supply the remaining template parameters. The single
// kernel argument is `const Flash_fwd_params params`, prefixed with
// KERNEL_PARAM_MODIFIER. The kernel body (in braces) follows the macro
// invocation.
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
// Forward kernel entry point: hands `params` to flash::compute_attn with the
// kernel's own template arguments. When ARCH_SUPPORTS_FLASH is not defined the
// body only prints the unsupported-architecture message (and the static_assert
// is not compiled).
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel,
                            bool Is_dropout, bool Is_causal, bool Is_local,
                            bool Has_alibi, bool Is_even_MN, bool Is_even_K,
                            bool Return_softmax) {
#if !defined(ARCH_SUPPORTS_FLASH)
    // The macro carries its own terminating ';'.
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local must never both be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local,
                        Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
#endif
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_local
,
bool
Has_alibi
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Split
,
bool
Append_KV
>
__global__
void
flash_fwd_splitkv_kernel
(
__grid_constant__
const
Flash_fwd_params
params
)
{
// Forward split-KV kernel entry point: hands `params` to
// flash::compute_attn_splitkv with the kernel's own template arguments. When
// ARCH_SUPPORTS_FLASH is not defined the body only prints the
// unsupported-architecture message.
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel,
                            bool Is_causal, bool Is_local, bool Has_alibi,
                            bool Is_even_MN, bool Is_even_K, bool Split,
                            bool Append_KV) {
#if !defined(ARCH_SUPPORTS_FLASH)
    // The macro carries its own terminating ';'.
    FLASH_UNSUPPORTED_ARCH
#else
    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi,
                                Is_even_MN, Is_even_K, Split, Append_KV>(params);
#endif
}
template
<
typename
Kernel_traits
,
int
kBlockM
,
int
Log_max_splits
,
bool
Is_even_K
>
__global__
void
flash_fwd_splitkv_combine_kernel
(
__grid_constant__
const
Flash_fwd_params
params
)
{
// Forward split-KV combine kernel entry point: hands `params` to
// flash::combine_attn_seqk_parallel with the kernel's own template arguments.
// Unlike the other forward kernels in this view, the body is not wrapped in an
// ARCH_SUPPORTS_FLASH check.
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel,
                            int kBlockM, int Log_max_splits, bool Is_even_K) {
    // Log_max_splits below 1 is rejected at compile time.
    static_assert(1 <= Log_max_splits);
    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits,
                                      Is_even_K>(params);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment