gaoqiong / flash-attention · Commit 88dc2040 (Unverified)

Authored Oct 04, 2022 by Tri Dao; committed by GitHub on Oct 04, 2022

Merge pull request #52 from bob80333/main

Make flash attention compile on Windows.

Parents: 0c01568d 2211db5f
Showing 4 changed files with 34 additions and 5 deletions (+34 -5)
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu  +3 -2
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu       +2 -2
csrc/flash_attn/src/fp16_switch.h                        +27 -0
setup.py                                                 +2 -1
csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu

@@ -2,6 +2,7 @@
  */
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

@@ -52,8 +53,8 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
 }
 
 void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
-    BOOL_SWITCH(params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    // work around for MSVC issue
+    FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if (params.seqlen_k == 128) {
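The change collapses the BOOL_SWITCH over is_bf16 plus a std::conditional type alias into a single FP16_SWITCH, because, per the commit's own comments, MSVC cannot handle std::conditional keyed on the constexpr variable that BOOL_SWITCH introduces. The workaround relies on the macro being textual: the caller's lambda is pasted into the branch that defines elem_type, so the alias is visible inside the lambda body. A rough hand-expansion of the new call site (my own sketch; the real kernel-selection body is elided):

// Rough hand-expansion of FP16_SWITCH(params.is_bf16, [&] { ... }).
// Illustrative only; the comments stand for the dispatch code above.
[&] {
    if (params.is_bf16) {
        using elem_type = __nv_bfloat16;  // alias in the true branch
        return ([&] { /* body here sees elem_type == __nv_bfloat16 */ })();
    } else {
        using elem_type = __half;         // alias in the false branch
        return ([&] { /* same body, sees elem_type == __half */ })();
    }
}();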
csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu

@@ -29,6 +29,7 @@
 #include <cuda_bf16.h>
 
 #include "static_switch.h"
+#include "fp16_switch.h"
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"

@@ -83,8 +84,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                         const bool configure) {
-    BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
-        using elem_type = std::conditional<IsBf16Const, __nv_bfloat16, __half>::type;
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if (launch_params.params.seqlen_k == 128) {
csrc/flash_attn/src/fp16_switch.h  (new file, 0 → 100644)
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// modified from static_switch.h
// because MSVC cannot handle std::conditional with constexpr variable
#pragma once
/// @param COND - a boolean expression to switch by
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// FP16_SWITCH(flag, [&] {
/// some_function(...);
/// });
/// ```
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = __nv_bfloat16; \
return __VA_ARGS__(); \
} else { \
using elem_type = __half; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
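To check the pattern in isolation, here is a self-contained host-only sketch (my own illustration, not part of the commit) with double and float standing in for __nv_bfloat16 and __half, so it builds with any C++ compiler and no CUDA headers:

#include <cstdio>

// Same shape as FP16_SWITCH, with host types standing in for the CUDA
// half/bfloat16 types so this builds without CUDA headers.
#define TYPE_SWITCH(COND, ...)        \
    [&] {                             \
        if (COND) {                   \
            using elem_type = double; \
            return __VA_ARGS__();     \
        } else {                      \
            using elem_type = float;  \
            return __VA_ARGS__();     \
        }                             \
    }()

int main(int argc, char **) {
    bool use_double = (argc > 1);  // runtime flag, like params.is_bf16
    TYPE_SWITCH(use_double, [&] {
        // The preprocessor pastes this lambda into both branches, so
        // `elem_type` here is whichever alias the taken branch defined.
        elem_type x = 1;
        std::printf("sizeof(elem_type) = %zu\n", sizeof(x));
        (void)x;
    });
    return 0;
}

Run with no arguments this prints sizeof(elem_type) = 4 (float); with any argument it prints 8 (double), showing the same lambda body binding to a different branch-local alias depending on the runtime flag.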
setup.py

@@ -125,10 +125,11 @@ ext_modules.append(
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"] + generator_flag,
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
+                    "-std=c++17",
                     "-U__CUDA_NO_HALF_OPERATORS__",
                     "-U__CUDA_NO_HALF_CONVERSIONS__",
                     "--expt-relaxed-constexpr",