Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
31ae2488
Unverified
Commit
31ae2488
authored
Jul 19, 2023
by
Tri Dao
Committed by
GitHub
Jul 19, 2023
Browse files
Merge pull request #343 from danthe3rd/if_constexpr
Fix compile error with `BOOL_SWITCH`
parents
d1a3b52f
538d570c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
49 deletions
+50
-49
csrc/flash_attn/src/static_switch.h
csrc/flash_attn/src/static_switch.h
+50
-49
No files found.
csrc/flash_attn/src/static_switch.h
View file @
31ae2488
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
...
...
@@ -13,53 +14,53 @@
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)
\
[&] {
\
if (COND) {
\
constexpr bool CONST_NAME = true;
\
return __VA_ARGS__();
\
} else {
\
constexpr bool CONST_NAME = false;
\
return __VA_ARGS__();
\
}
\
}()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr
static
bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr
static
bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...)
\
[&] {
\
if (COND) {
\
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else {
\
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
}
\
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 96) { \
constexpr int kHeadDim = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 160) { \
constexpr int kHeadDim = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 192) { \
constexpr int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 224) { \
constexpr int kHeadDim = 224; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr int kHeadDim = 256; \
return __VA_ARGS__(); \
}
\
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...)
\
[&] {
\
if (HEADDIM <= 32) {
\
constexpr
static
int kHeadDim = 32; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 64) {
\
constexpr
static
int kHeadDim = 64; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 96) {
\
constexpr
static
int kHeadDim = 96; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 128) {
\
constexpr
static
int kHeadDim = 128; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 160) {
\
constexpr
static
int kHeadDim = 160; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 192) {
\
constexpr
static
int kHeadDim = 192; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 224) {
\
constexpr
static
int kHeadDim = 224; \
return __VA_ARGS__();
\
} else if (HEADDIM <= 256) {
\
constexpr
static
int kHeadDim = 256; \
return __VA_ARGS__();
\
}
\
}()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment