static_switch.h 2.89 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

#include <stdexcept>

/// @param COND       - a boolean expression to switch on
/// @param CONST_NAME - the name given to the constexpr bool variable that is
///                     visible inside the callable
/// @param ...        - a callable (typically a lambda) that is invoked in
///                     whichever branch COND selects; its return value is
///                     returned from the switch expression
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
Tri Dao's avatar
Tri Dao committed
16
17
18
19
20
21
22
23
24
25
// Turns the runtime boolean COND into a compile-time constant named
// CONST_NAME, then invokes the callable passed in __VA_ARGS__ with that
// constant in scope. The whole construct is an immediately-invoked lambda, so
// it is an expression whose value is whatever the callable returns.
#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
    [&] {                                                                            \
        if (COND) {                                                                  \
            constexpr bool CONST_NAME = true;                                        \
            return __VA_ARGS__();                                                    \
        }                                                                            \
        constexpr bool CONST_NAME = false;                                           \
        return __VA_ARGS__();                                                        \
    }()
26

Tri Dao's avatar
Tri Dao committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Selects the element type for the callable passed in __VA_ARGS__: when COND
// is true, `elem_type` is an alias for cutlass::half_t; otherwise it is an
// alias for cutlass::bfloat16_t. The alias is visible inside the callable,
// and the expression evaluates to whatever the callable returns.
#define FP16_SWITCH(COND, ...)                     \
    [&] {                                          \
        if (COND) {                                \
            using elem_type = cutlass::half_t;     \
            return __VA_ARGS__();                  \
        }                                          \
        using elem_type = cutlass::bfloat16_t;     \
        return __VA_ARGS__();                      \
    }()

// Maps a runtime head dimension onto the smallest supported compile-time
// bucket that can hold it (32, 64, 96, 128, 160, 192, 224 or 256). The bucket
// is exposed as `constexpr int kHeadDim` to the callable passed in
// __VA_ARGS__, and the expression evaluates to whatever the callable returns.
// Values <= 32 (including non-positive ones) map to the 32 bucket.
//
// HEADDIM is parenthesized and evaluated exactly once. This makes arguments
// such as `flag ? 64 : 128` or expressions with side effects behave as
// written; a bare `HEADDIM <= 32` would bind `<=` into the argument expression
// and re-evaluate it in every branch.
//
// No bucket exists above 256. Rather than falling off the end of the lambda
// (which silently skips a void callable and is undefined behavior for a
// value-returning one), such head dimensions throw std::invalid_argument.
#define FWD_HEADDIM_SWITCH(HEADDIM, ...)                                                  \
    [&] {                                                                                 \
        const auto fwd_headdim_switch_value_ = (HEADDIM);                                 \
        if (fwd_headdim_switch_value_ <= 32) {                                            \
            constexpr int kHeadDim = 32;                                                  \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 64) {                                     \
            constexpr int kHeadDim = 64;                                                  \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 96) {                                     \
            constexpr int kHeadDim = 96;                                                  \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 128) {                                    \
            constexpr int kHeadDim = 128;                                                 \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 160) {                                    \
            constexpr int kHeadDim = 160;                                                 \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 192) {                                    \
            constexpr int kHeadDim = 192;                                                 \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 224) {                                    \
            constexpr int kHeadDim = 224;                                                 \
            return __VA_ARGS__();                                                         \
        } else if (fwd_headdim_switch_value_ <= 256) {                                    \
            constexpr int kHeadDim = 256;                                                 \
            return __VA_ARGS__();                                                         \
        } else {                                                                          \
            throw std::invalid_argument("FWD_HEADDIM_SWITCH: head dimension must be <= 256"); \
        }                                                                                 \
    }()