Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
a1f49a2b
Commit
a1f49a2b
authored
Jan 06, 2023
by
Tri Dao
Browse files
[Compilation] Change BOOL_SWITCH to fix Windows compilation
Follow xFormers's DISPATCH_BOOL. Haven't tested it on Windows.
parent
a668890f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
29 additions
and
24 deletions
+29
-24
csrc/flash_attn/src/fmha_bwd_hdim128.cu
csrc/flash_attn/src/fmha_bwd_hdim128.cu
+1
-1
csrc/flash_attn/src/fmha_bwd_hdim32.cu
csrc/flash_attn/src/fmha_bwd_hdim32.cu
+1
-1
csrc/flash_attn/src/fmha_bwd_hdim64.cu
csrc/flash_attn/src/fmha_bwd_hdim64.cu
+1
-1
csrc/flash_attn/src/fmha_bwd_launch_template.h
csrc/flash_attn/src/fmha_bwd_launch_template.h
+1
-1
csrc/flash_attn/src/fmha_fwd_hdim128.cu
csrc/flash_attn/src/fmha_fwd_hdim128.cu
+1
-1
csrc/flash_attn/src/fmha_fwd_hdim32.cu
csrc/flash_attn/src/fmha_fwd_hdim32.cu
+1
-1
csrc/flash_attn/src/fmha_fwd_hdim64.cu
csrc/flash_attn/src/fmha_fwd_hdim64.cu
+1
-1
csrc/flash_attn/src/fmha_fwd_launch_template.h
csrc/flash_attn/src/fmha_fwd_launch_template.h
+1
-1
csrc/flash_attn/src/static_switch.h
csrc/flash_attn/src/static_switch.h
+21
-16
No files found.
csrc/flash_attn/src/fmha_bwd_hdim128.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim128
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
FP16_SWITCH
(
params
.
is_bf16
,
({
FP16_SWITCH
(
params
.
is_bf16
,
(
[
&
]
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
128
,
16
,
1
,
8
,
0x100u
,
elem_type
>
;
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
}));
...
...
csrc/flash_attn/src/fmha_bwd_hdim32.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim32
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
FP16_SWITCH
(
params
.
is_bf16
,
({
FP16_SWITCH
(
params
.
is_bf16
,
(
[
&
]
{
if
(
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
32
,
16
,
1
,
8
,
0x08u
,
elem_type
>
;
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
...
...
csrc/flash_attn/src/fmha_bwd_hdim64.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim64
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
FP16_SWITCH
(
params
.
is_bf16
,
({
FP16_SWITCH
(
params
.
is_bf16
,
(
[
&
]
{
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
if
(
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
64
,
16
,
1
,
8
,
0x08u
,
elem_type
>
;
...
...
csrc/flash_attn/src/fmha_bwd_launch_template.h
View file @
a1f49a2b
...
...
@@ -61,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
bool
is_dropout
=
params
.
p_dropout
<
1.
f
;
// params.p_dropout is the probability of "keeping"
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
({
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
(
[
&
]
{
auto
kernel
=
params
.
is_causal
?
&
fmha_bwd_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
>
:
&
fmha_bwd_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
>
;
...
...
csrc/flash_attn/src/fmha_fwd_hdim128.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim128
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
({
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
[
&
]
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
128
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
}));
...
...
csrc/flash_attn/src/fmha_fwd_hdim32.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim32
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
({
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
[
&
]
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
32
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
...
...
csrc/flash_attn/src/fmha_fwd_hdim64.cu
View file @
a1f49a2b
...
...
@@ -5,7 +5,7 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim64
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
({
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
[
&
]
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
64
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
...
...
csrc/flash_attn/src/fmha_fwd_launch_template.h
View file @
a1f49a2b
...
...
@@ -56,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
({
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
(
[
&
]
{
auto
kernel
=
launch_params
.
params
.
is_causal
?
(
launch_params
.
return_softmax
?
&
fmha_fwd_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
true
>
...
...
csrc/flash_attn/src/static_switch.h
View file @
a1f49a2b
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8
#pragma once
...
...
@@ -9,27 +10,31 @@
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, ({
/// BOOL_SWITCH(flag, BoolConst, (
[&]
{
/// some_function<BoolConst>(...);
/// }));
/// ```
/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
#define BOOL_SWITCH(COND, CONST_NAME, CODE) \
if (COND) { \
constexpr bool CONST_NAME = true; \
CODE; \
} else { \
constexpr bool CONST_NAME = false; \
CODE; \
#define BOOL_SWITCH(COND, CONST_NAME, F) \
{ \
if (COND) { \
constexpr bool CONST_NAME = true; \
F(); \
} else { \
constexpr bool CONST_NAME = false; \
F(); \
} \
}
// modified from BOOL_SWITCH
// because MSVC cannot handle std::conditional with constexpr variable
#define FP16_SWITCH(COND, CODE) \
if (COND) { \
using elem_type = __nv_bfloat16; \
CODE; \
} else { \
using elem_type = __half; \
CODE; \
} \
#define FP16_SWITCH(COND, F) \
{ \
if (COND) { \
using elem_type = __nv_bfloat16; \
F(); \
} else { \
using elem_type = __half; \
F(); \
} \
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment