Commit a1f49a2b authored by Tri Dao

[Compilation] Change BOOL_SWITCH to fix Windows compilation

Follow xFormers's DISPATCH_BOOL. Haven't tested it on Windows.
parent a668890f
@@ -5,7 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    FP16_SWITCH(params.is_bf16, ({
+    FP16_SWITCH(params.is_bf16, ([&] {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
         run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
     }));

@@ -5,7 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    FP16_SWITCH(params.is_bf16, ({
+    FP16_SWITCH(params.is_bf16, ([&] {
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);

@@ -5,7 +5,7 @@
 #include "fmha_bwd_launch_template.h"
 
 void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    FP16_SWITCH(params.is_bf16, ({
+    FP16_SWITCH(params.is_bf16, ([&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;

@@ -61,7 +61,7 @@ void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
-    BOOL_SWITCH(is_dropout, IsDropoutConst, ({
+    BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] {
         auto kernel = params.is_causal
             ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
             : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;

@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, ({
+    FP16_SWITCH(launch_params.params.is_bf16, ([&] {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fwd_loop<Kernel_traits>(launch_params);
     }));

@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, ({
+    FP16_SWITCH(launch_params.params.is_bf16, ([&] {
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);

@@ -5,7 +5,7 @@
 #include "fmha_fwd_launch_template.h"
 
 void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, ({
+    FP16_SWITCH(launch_params.params.is_bf16, ([&] {
         if (launch_params.params.seqlen_k == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
             run_fmha_fwd_loop<Kernel_traits>(launch_params);

@@ -56,7 +56,7 @@ void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     // https://github.com/kokkos/kokkos-kernels/issues/349
     // https://github.com/HazyResearch/flash-attention/issues/21
-    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ({
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] {
         auto kernel = launch_params.params.is_causal
             ? (launch_params.return_softmax
                 ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>

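The "work-around for gcc 7" comments above refer to replacing a second, nested BOOL_SWITCH with a ternary over is_causal. For illustration only, here is a hypothetical, self-contained sketch of what the nested form looks like with the lambda-based macro (launch_stub, is_dropout, and is_causal are invented stand-ins; the macro body is copied from the new static_switch.h below):

// Hypothetical standalone sketch: nested dispatch with the lambda-based
// BOOL_SWITCH. Two runtime bools become four template instantiations.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, F)       \
    {                                          \
        if (COND) {                            \
            constexpr bool CONST_NAME = true;  \
            F();                               \
        } else {                               \
            constexpr bool CONST_NAME = false; \
            F();                               \
        }                                      \
    }

// Made-up stand-in for a kernel template taking two bool parameters.
template <bool IsDropout, bool IsCausal>
void launch_stub() {
    std::printf("IsDropout=%d, IsCausal=%d\n", IsDropout, IsCausal);
}

int main() {
    bool is_dropout = true, is_causal = false;  // runtime values
    // Each level turns a runtime bool into a constexpr; the inner lambda
    // sees both constants, so the template arguments are compile-time.
    BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] {
        BOOL_SWITCH(is_causal, IsCausalConst, ([&] {
            launch_stub<IsDropoutConst, IsCausalConst>();
        }));
    }));
    return 0;
}
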
 // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
 // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8
 
 #pragma once
@@ -9,27 +10,31 @@
 ///
 /// Usage:
 /// ```
-/// BOOL_SWITCH(flag, BoolConst, ({
+/// BOOL_SWITCH(flag, BoolConst, ([&] {
 ///     some_function<BoolConst>(...);
 /// }));
 /// ```
 /// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
-#define BOOL_SWITCH(COND, CONST_NAME, CODE)        \
-    if (COND) {                                    \
-        constexpr bool CONST_NAME = true;          \
-        CODE;                                      \
-    } else {                                       \
-        constexpr bool CONST_NAME = false;         \
-        CODE;                                      \
-    }
+#define BOOL_SWITCH(COND, CONST_NAME, F)           \
+    {                                              \
+        if (COND) {                                \
+            constexpr bool CONST_NAME = true;      \
+            F();                                   \
+        } else {                                   \
+            constexpr bool CONST_NAME = false;     \
+            F();                                   \
+        }                                          \
+    }
 
 // modified from BOOL_SWITCH
 // because MSVC cannot handle std::conditional with constexpr variable
-#define FP16_SWITCH(COND, CODE)                    \
-    if (COND) {                                    \
-        using elem_type = __nv_bfloat16;           \
-        CODE;                                      \
-    } else {                                       \
-        using elem_type = __half;                  \
-        CODE;                                      \
-    }
+#define FP16_SWITCH(COND, F)                       \
+    {                                              \
+        if (COND) {                                \
+            using elem_type = __nv_bfloat16;       \
+            F();                                   \
+        } else {                                   \
+            using elem_type = __half;              \
+            F();                                   \
+        }                                          \
+    }
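Why this fixes the Windows build: the old macros expanded their argument as `CODE;`, with CODE written at call sites as `({ ... })`, a brace-enclosed block in parentheses that GCC accepts as a statement-expression but MSVC rejects; the new macros receive a lambda and invoke it with `F()`, which is standard C++. A hypothetical host-only sketch of the FP16_SWITCH pattern (TYPE_SWITCH, run, and use_double are invented here, with float/double standing in for __half/__nv_bfloat16 so it compiles without CUDA headers):

// Hypothetical sketch: dispatching a type alias into a lambda body.
#include <cstdio>

#define TYPE_SWITCH(COND, F)          \
    {                                 \
        if (COND) {                   \
            using elem_type = double; \
            F();                      \
        } else {                      \
            using elem_type = float;  \
            F();                      \
        }                             \
    }

template <typename T>
void run(int n) {
    std::printf("n=%d, sizeof(elem)=%zu\n", n, sizeof(T));
}

int main() {
    bool use_double = true;  // runtime flag, like params.is_bf16 in the diff
    // The lambda body is compiled once per branch, each time with a different
    // meaning for elem_type; the extra parentheses keep the lambda a single
    // macro argument even if its body contains commas.
    TYPE_SWITCH(use_double, ([&] {
        run<elem_type>(128);
    }));
    return 0;
}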