// Runtime-to-compile-time dispatch macros.
//
// Each macro compares a runtime value, declares a `static constexpr` constant
// holding the selected value inside an immediately invoked `[&]` lambda, and
// returns the result of calling the callable passed as the trailing macro
// argument(s). Because the constant has static storage duration, the callable
// can use it (for example as a template argument) without capturing it.
// The callable is taken through `...`/`__VA_ARGS__`, so commas inside it that
// are not enclosed in parentheses (e.g. in a template argument list) are
// re-joined rather than producing a macro-argument-count error.
//
// Selector arguments are parenthesized in the expansion so that expressions
// such as `a ? b : c` or `x & mask` are compared as a whole. A selector may be
// evaluated more than once, so it should not have side effects.

// Declares `static constexpr bool CONST_NAME`, equal to whether COND is true.
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      static constexpr bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

// Declares `static constexpr int NUM_THREADS`: 256 or 128 when NUM_THREAD
// equals that value; every other value silently selects 64 (no error).
#define NUM_THREADS_SWITCH(NUM_THREAD, ...)   \
  [&] {                                       \
    if ((NUM_THREAD) == 256) {                \
      static constexpr int NUM_THREADS = 256; \
      return __VA_ARGS__();                   \
    } else if ((NUM_THREAD) == 128) {         \
      static constexpr int NUM_THREADS = 128; \
      return __VA_ARGS__();                   \
    } else {                                  \
      static constexpr int NUM_THREADS = 64;  \
      return __VA_ARGS__();                   \
    }                                         \
  }()

// Declares `static constexpr int HEAD_SIZE` for the supported head sizes
// 64, 80, 96, 112, 128 and 256. Any other HEADDIM reaches
// TORCH_CHECK(false, "Unsupported head size: ", HEADDIM). That branch has no
// `return`, so it relies on TORCH_CHECK throwing when its condition is false
// (as PyTorch's c10 TORCH_CHECK does; confirm if another definition is used).
#define HEADSIZE_SWITCH(HEADDIM, ...)                         \
  [&] {                                                       \
    if ((HEADDIM) == 64) {                                    \
      static constexpr int HEAD_SIZE = 64;                    \
      return __VA_ARGS__();                                   \
    } else if ((HEADDIM) == 80) {                             \
      static constexpr int HEAD_SIZE = 80;                    \
      return __VA_ARGS__();                                   \
    } else if ((HEADDIM) == 96) {                             \
      static constexpr int HEAD_SIZE = 96;                    \
      return __VA_ARGS__();                                   \
    } else if ((HEADDIM) == 112) {                            \
      static constexpr int HEAD_SIZE = 112;                   \
      return __VA_ARGS__();                                   \
    } else if ((HEADDIM) == 128) {                            \
      static constexpr int HEAD_SIZE = 128;                   \
      return __VA_ARGS__();                                   \
    } else if ((HEADDIM) == 256) {                            \
      static constexpr int HEAD_SIZE = 256;                   \
      return __VA_ARGS__();                                   \
    } else {                                                  \
      TORCH_CHECK(false, "Unsupported head size: ", HEADDIM); \
    }                                                         \
  }()

// Declares `static constexpr int REUSE_KV_TIMES`: 16, 8, 4 or 2 when reusekv
// equals that value; every other value silently selects 1 (no error).
#define REUSEKV_SWITCH(reusekv, ...)            \
  [&] {                                         \
    if ((reusekv) == 16) {                      \
      static constexpr int REUSE_KV_TIMES = 16; \
      return __VA_ARGS__();                     \
    } else if ((reusekv) == 8) {                \
      static constexpr int REUSE_KV_TIMES = 8;  \
      return __VA_ARGS__();                     \
    } else if ((reusekv) == 4) {                \
      static constexpr int REUSE_KV_TIMES = 4;  \
      return __VA_ARGS__();                     \
    } else if ((reusekv) == 2) {                \
      static constexpr int REUSE_KV_TIMES = 2;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      static constexpr int REUSE_KV_TIMES = 1;  \
      return __VA_ARGS__();                     \
    }                                           \
  }()

// Declares `static constexpr bool use_vmac`: false when REUSE_KV_TIMES == 1
// and either num_blocks > 2500 or padded_max_seq_len > 2048; true otherwise.
// The 2500 / 2048 thresholds are hard-coded here.
// NOTE: not self-contained. The expansion reads two names that must already be
// in scope where the macro is used: REUSE_KV_TIMES (e.g. the constant declared
// by an enclosing REUSEKV_SWITCH) and padded_max_seq_len.
// `use_vmac` is a bool because it only ever holds true/false; it still
// converts to 0/1 anywhere an int is expected.
#define USEVMAC_SWITCH_V1(num_blocks, ...)                       \
  [&] {                                                          \
    if ((REUSE_KV_TIMES) == 1 &&                                 \
        ((num_blocks) > 2500 || padded_max_seq_len > 2048)) {    \
      static constexpr bool use_vmac = false;                    \
      return __VA_ARGS__();                                      \
    } else {                                                     \
      static constexpr bool use_vmac = true;                     \
      return __VA_ARGS__();                                      \
    }                                                            \
  }()