Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8a2ece89
Commit
8a2ece89
authored
Dec 06, 2022
by
Tri Dao
Browse files
Simplify BOOL_SWITCH macro to fix compiling error on gcc 7
parent
a84d0728
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
40 additions
and
61 deletions
+40
-61
csrc/flash_attn/src/fmha.h
csrc/flash_attn/src/fmha.h
+2
-1
csrc/flash_attn/src/fmha_bwd_hdim128.cu
csrc/flash_attn/src/fmha_bwd_hdim128.cu
+2
-3
csrc/flash_attn/src/fmha_bwd_hdim32.cu
csrc/flash_attn/src/fmha_bwd_hdim32.cu
+2
-3
csrc/flash_attn/src/fmha_bwd_hdim64.cu
csrc/flash_attn/src/fmha_bwd_hdim64.cu
+2
-3
csrc/flash_attn/src/fmha_bwd_launch_template.h
csrc/flash_attn/src/fmha_bwd_launch_template.h
+2
-3
csrc/flash_attn/src/fmha_fwd_hdim128.cu
csrc/flash_attn/src/fmha_fwd_hdim128.cu
+2
-2
csrc/flash_attn/src/fmha_fwd_hdim32.cu
csrc/flash_attn/src/fmha_fwd_hdim32.cu
+2
-2
csrc/flash_attn/src/fmha_fwd_hdim64.cu
csrc/flash_attn/src/fmha_fwd_hdim64.cu
+2
-2
csrc/flash_attn/src/fmha_fwd_launch_template.h
csrc/flash_attn/src/fmha_fwd_launch_template.h
+2
-3
csrc/flash_attn/src/fp16_switch.h
csrc/flash_attn/src/fp16_switch.h
+0
-27
csrc/flash_attn/src/static_switch.h
csrc/flash_attn/src/static_switch.h
+22
-12
No files found.
csrc/flash_attn/src/fmha.h
View file @
8a2ece89
...
...
@@ -36,7 +36,8 @@
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/UnpackRaw.cuh>
#include <fmha_utils.h>
...
...
csrc/flash_attn/src/fmha_bwd_hdim128.cu
View file @
8a2ece89
...
...
@@ -5,9 +5,8 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim128
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
// work around for MSVC issue
FP16_SWITCH
(
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
params
.
is_bf16
,
({
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
128
,
16
,
1
,
8
,
0x100u
,
elem_type
>
;
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
});
})
)
;
}
\ No newline at end of file
csrc/flash_attn/src/fmha_bwd_hdim32.cu
View file @
8a2ece89
...
...
@@ -5,8 +5,7 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim32
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
// work around for MSVC issue
FP16_SWITCH
(
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
params
.
is_bf16
,
({
if
(
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
32
,
16
,
1
,
8
,
0x08u
,
elem_type
>
;
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
...
...
@@ -14,5 +13,5 @@ void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const b
using
Kernel_traits
=
FMHA_kernel_traits
<
256
,
32
,
16
,
1
,
8
,
0x08u
,
elem_type
>
;
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
}
});
})
)
;
}
\ No newline at end of file
csrc/flash_attn/src/fmha_bwd_hdim64.cu
View file @
8a2ece89
...
...
@@ -5,8 +5,7 @@
#include "fmha_bwd_launch_template.h"
void
run_fmha_bwd_hdim64
(
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
,
const
bool
configure
)
{
// work around for MSVC issue
FP16_SWITCH
(
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
params
.
is_bf16
,
({
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
if
(
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
64
,
16
,
1
,
8
,
0x08u
,
elem_type
>
;
...
...
@@ -27,5 +26,5 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
run_fmha_bwd_loop
<
Kernel_traits
>
(
params
,
stream
,
configure
);
}
}
});
})
)
;
}
\ No newline at end of file
csrc/flash_attn/src/fmha_bwd_launch_template.h
View file @
8a2ece89
...
...
@@ -3,7 +3,6 @@
#pragma once
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
...
...
@@ -62,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
bool
is_dropout
=
params
.
p_dropout
<
1.
f
;
// params.p_dropout is the probability of "keeping"
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
(
{
auto
kernel
=
params
.
is_causal
?
&
fmha_bwd_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
>
:
&
fmha_bwd_dq_dk_dv_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
false
>
;
...
...
@@ -111,5 +110,5 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const boo
kernel_seqparallel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size_dq_dk_dv
,
stream
>>>
(
params
);
}
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
});
})
)
;
}
csrc/flash_attn/src/fmha_fwd_hdim128.cu
View file @
8a2ece89
...
...
@@ -5,8 +5,8 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim128
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
128
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
});
})
)
;
}
\ No newline at end of file
csrc/flash_attn/src/fmha_fwd_hdim32.cu
View file @
8a2ece89
...
...
@@ -5,7 +5,7 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim32
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
32
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
...
...
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
using
Kernel_traits
=
FMHA_kernel_traits
<
256
,
32
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
}
});
})
)
;
}
\ No newline at end of file
csrc/flash_attn/src/fmha_fwd_hdim64.cu
View file @
8a2ece89
...
...
@@ -5,7 +5,7 @@
#include "fmha_fwd_launch_template.h"
void
run_fmha_fwd_hdim64
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
)
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
[
&
]
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
(
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
using
Kernel_traits
=
FMHA_kernel_traits
<
128
,
64
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
...
...
@@ -13,5 +13,5 @@ void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
using
Kernel_traits
=
FMHA_kernel_traits
<
256
,
64
,
16
,
1
,
4
,
0x08u
,
elem_type
>
;
run_fmha_fwd_loop
<
Kernel_traits
>
(
launch_params
);
}
});
})
)
;
}
csrc/flash_attn/src/fmha_fwd_launch_template.h
View file @
8a2ece89
...
...
@@ -8,7 +8,6 @@
#include <cuda_bf16.h>
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
...
...
@@ -57,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
launch_params
.
is_dropout
,
IsDropoutConst
,
(
{
auto
kernel
=
launch_params
.
params
.
is_causal
?
(
launch_params
.
return_softmax
?
&
fmha_fwd_loop_kernel
<
Kernel_traits
,
IsDropoutConst
,
true
,
true
>
...
...
@@ -88,5 +87,5 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
kernel
<<<
grid
,
Kernel_traits
::
THREADS
,
smem_size
,
launch_params
.
stream
>>>
(
launch_params
.
params
);
FMHA_CHECK_CUDA
(
cudaPeekAtLastError
());
});
})
)
;
}
csrc/flash_attn/src/fp16_switch.h
deleted
100644 → 0
View file @
a84d0728
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// modified from static_switch.h
// because MSVC cannot handle std::conditional with constexpr variable
#pragma once
/// Runs the same callable with `elem_type` bound to one of two element types.
///
/// The macro expands to an immediately-invoked `[&]` lambda. Inside that lambda,
/// `elem_type` is aliased to `__nv_bfloat16` when COND is true and to `__half`
/// otherwise; in either branch the variadic argument is invoked with `()` and
/// its result is returned from the lambda.
///
/// @param COND - a boolean expression to switch by
/// @param ... - code to execute for true and false; it is pasted textually into
///              both branches (so it may refer to `elem_type`) and must be
///              callable with no arguments
///
/// Usage:
/// ```
/// FP16_SWITCH(flag, [&] {
/// some_function(...);
/// });
/// ```
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
using elem_type = __half; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
csrc/flash_attn/src/static_switch.h
View file @
8a2ece89
...
...
@@ -9,17 +9,27 @@
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst,
[&]
{
/// BOOL_SWITCH(flag, BoolConst,
(
{
/// some_function<BoolConst>(...);
/// });
/// })
)
;
/// ```
/// Lambda form of BOOL_SWITCH: expands to an immediately-invoked `[&]` lambda.
/// Inside it, CONST_NAME is declared as a `constexpr bool` (true when COND holds,
/// false otherwise), and the variadic argument is invoked with `()` in each
/// branch, with its result returned from the lambda.
/// NOTE(review): this diff view also shows a statement-form BOOL_SWITCH right
/// after this definition; presumably this lambda form is the removed side of the
/// diff - confirm against the repository before relying on it.
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
/// Statement form of BOOL_SWITCH: expands to a plain if/else on COND. In the
/// first branch CONST_NAME is a `constexpr bool` equal to true, in the second it
/// is false; CODE is pasted once into each branch, followed by `;`.
/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
/// (A comma in CODE that is not enclosed in parentheses would otherwise be read
/// by the preprocessor as a separator between macro arguments.)
/// The expansion is a bare if/else statement, not an expression, so it yields no value.
#define BOOL_SWITCH(COND, CONST_NAME, CODE) \
if (COND) { \
constexpr bool CONST_NAME = true; \
CODE; \
} else { \
constexpr bool CONST_NAME = false; \
CODE; \
}
// modified from BOOL_SWITCH
// because MSVC cannot handle std::conditional with constexpr variable
//
// Expands to a plain if/else on COND. In the first branch `elem_type` is an
// alias for `__nv_bfloat16`, in the second for `__half`; CODE is pasted once
// into each branch, followed by `;`, so it may refer to `elem_type`.
// CODE must be a single macro argument: wrap it in "({" and "})" (or in
// parentheses) when it contains commas that are not already parenthesized.
//
// The final line of the definition deliberately has NO trailing backslash: a
// line continuation there would splice whatever line follows the macro into
// its replacement list.
#define FP16_SWITCH(COND, CODE)              \
    if (COND) {                              \
        using elem_type = __nv_bfloat16;     \
        CODE;                                \
    } else {                                 \
        using elem_type = __half;            \
        CODE;                                \
    }
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment