Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
9d7fd5b6
Commit
9d7fd5b6
authored
Oct 03, 2022
by
Eric Engelhart
Browse files
Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates
parent
0c01568d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
4 deletions
+32
-4
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+3
-2
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+2
-2
csrc/flash_attn/src/fp16_switch.h
csrc/flash_attn/src/fp16_switch.h
+27
-0
No files found.
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
View file @
9d7fd5b6
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
*/
*/
#include "static_switch.h"
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
...
@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
...
@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
}
}
void
run_fmha_dgrad_fp16_sm80
(
const
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
)
{
void
run_fmha_dgrad_fp16_sm80
(
const
FMHA_dgrad_params
&
params
,
cudaStream_t
stream
)
{
BOOL_SWITCH
(
params
.
is_bf16
,
IsBf16Const
,
[
&
]
{
// work around for MSVC issue
using
elem_type
=
std
::
conditional
<
IsBf16Const
,
__nv_bfloat16
,
__half
>::
type
;
FP16_SWITCH
(
params
.
is_bf16
,
[
&
]
{
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
if
(
params
.
d
==
16
)
{
if
(
params
.
d
==
16
)
{
if
(
params
.
seqlen_k
==
128
)
{
if
(
params
.
seqlen_k
==
128
)
{
...
...
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
View file @
9d7fd5b6
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include <cuda_bf16.h>
#include <cuda_bf16.h>
#include "static_switch.h"
#include "static_switch.h"
#include "fp16_switch.h"
#include "fmha.h"
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_fprop_kernel_1xN.h"
...
@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
...
@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
void
run_fmha_fp16_sm80
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
,
void
run_fmha_fp16_sm80
(
Launch_params
<
FMHA_fprop_params
>
&
launch_params
,
const
bool
configure
)
{
const
bool
configure
)
{
BOOL_SWITCH
(
launch_params
.
params
.
is_bf16
,
IsBf16Const
,
[
&
]
{
FP16_SWITCH
(
launch_params
.
params
.
is_bf16
,
[
&
]
{
using
elem_type
=
std
::
conditional
<
IsBf16Const
,
__nv_bfloat16
,
__half
>::
type
;
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
auto
dprops
=
at
::
cuda
::
getCurrentDeviceProperties
();
if
(
launch_params
.
params
.
d
==
16
)
{
if
(
launch_params
.
params
.
d
==
16
)
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
if
(
launch_params
.
params
.
seqlen_k
==
128
)
{
...
...
csrc/flash_attn/src/fp16_switch.h
0 → 100644
View file @
9d7fd5b6
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// modified from static_switch.h
// because MSVC cannot handle std::conditional with constexpr variable
#pragma once
/// Run-time to compile-time dispatch over the two 16-bit floating-point
/// element types.
///
/// The callable passed as the variadic argument is pasted textually into both
/// branches below, directly after a branch-local `elem_type` alias. Its body
/// can therefore refer to `elem_type`, and it is compiled once per type.
///
/// @param COND - a boolean expression to switch by:
///               true  -> `elem_type` is `__nv_bfloat16`
///               false -> `elem_type` is `__half`
/// @param ...  - callable (typically a `[&] { ... }` lambda) to invoke; its
///               body may use `elem_type`
/// @return whatever the callable returns (both instantiations must agree on
///         the return type, since they share one enclosing lambda)
///
/// Usage:
/// ```
/// FP16_SWITCH(flag, [&] {
///     some_function<elem_type>(...);
/// });
/// ```
///
/// NOTE: the aliases are spelled as plain types instead of going through
/// std::conditional with a constexpr flag, because that form is what MSVC
/// fails to handle (see the header comment of this file).
#define FP16_SWITCH(COND, ...)                \
    [&] {                                     \
        if (COND) {                           \
            /* bf16 path */                   \
            using elem_type = __nv_bfloat16;  \
            return __VA_ARGS__();             \
        } else {                              \
            /* fp16 path: must be __half */   \
            using elem_type = __half;         \
            return __VA_ARGS__();             \
        }                                     \
    }()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment