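// static_switch.h
//
// Compile-time dispatch helpers: each *_SWITCH macro turns a runtime value
// into a `constexpr static` constant and then invokes the callable passed in
// its trailing arguments, returning that callable's result.
//
// Provenance (from the repository web view this copy was taken from):
// committed by zhangshao; page banner referenced revision
// de10ff0b7cc757af3d0374d82c1a2130196af496.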
// BOOL_SWITCH(COND, CONST_NAME, F)
//   Tests the runtime condition COND once, then calls F() with a
//   `constexpr static bool CONST_NAME` in scope that equals the outcome
//   (true when COND holds, false otherwise). Expands to F()'s return value.
#define BOOL_SWITCH(COND, CONST_NAME, ...)          \
  [&] {                                             \
    if (!(COND)) {                                  \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
    constexpr static bool CONST_NAME = true;        \
    return __VA_ARGS__();                           \
  }()

// OPT_SWITCH(COND, F)
//   Calls F() with a `constexpr static int opt` in scope: opt is 1 when the
//   runtime condition COND holds and 2 otherwise. The constant's name is
//   fixed to `opt` (not caller-chosen). Expands to F()'s return value.
#define OPT_SWITCH(COND, ...)                 \
  [&] {                                       \
    if (!(COND)) {                            \
      constexpr static int opt = 2;           \
      return __VA_ARGS__();                   \
    }                                         \
    constexpr static int opt = 1;             \
    return __VA_ARGS__();                     \
  }()

// NUM_THREADS_SWITCH(NUM_THREAD, F)
//   Maps the runtime thread count NUM_THREAD to a compile-time
//   `constexpr static int NUM_THREADS` and returns F().
//   Only 256 is matched explicitly; ANY other value (e.g. 64 or 512) selects
//   NUM_THREADS = 128 without diagnostics.
//   NUM_THREAD is parenthesized so argument expressions such as `n & mask`
//   are compared against 256 as a whole.
#define NUM_THREADS_SWITCH(NUM_THREAD, ...)     \
  [&] {                                         \
    if ((NUM_THREAD) == 256) {                  \
      constexpr static int NUM_THREADS = 256;   \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static int NUM_THREADS = 128;   \
      return __VA_ARGS__();                     \
    }                                           \
  }()

  // HEADSIZE_SWITCH(HEADDIM, F)
  //   Maps the runtime head dimension HEADDIM to a compile-time
  //   `constexpr static int HEAD_SIZE` and returns F().
  //   Supported values: 64, 80, 96, 112, 128, 256. Any other value reaches
  //   TORCH_CHECK(false, "Unsupported head size: ", <value>), so TORCH_CHECK
  //   must be defined wherever this macro is expanded.
  //   HEADDIM is evaluated exactly once and parenthesized, so arguments with
  //   side effects or low-precedence operators (e.g. `d & mask`) behave as a
  //   single value.
  #define HEADSIZE_SWITCH(HEADDIM, ...)                               \
  [&] {                                                             \
    const auto static_switch_head_dim_ = (HEADDIM);                 \
    if (static_switch_head_dim_ == 64) {                            \
      constexpr static int HEAD_SIZE = 64;                          \
      return __VA_ARGS__();                                         \
    } else if (static_switch_head_dim_ == 80) {                     \
      constexpr static int HEAD_SIZE = 80;                          \
      return __VA_ARGS__();                                         \
    } else if (static_switch_head_dim_ == 96) {                     \
      constexpr static int HEAD_SIZE = 96;                          \
      return __VA_ARGS__();                                         \
    } else if (static_switch_head_dim_ == 112) {                    \
      constexpr static int HEAD_SIZE = 112;                         \
      return __VA_ARGS__();                                         \
    } else if (static_switch_head_dim_ == 128) {                    \
      constexpr static int HEAD_SIZE = 128;                         \
      return __VA_ARGS__();                                         \
    } else if (static_switch_head_dim_ == 256) {                    \
      constexpr static int HEAD_SIZE = 256;                         \
      return __VA_ARGS__();                                         \
    } else {                                                        \
      TORCH_CHECK(false, "Unsupported head size: ",                 \
                  static_switch_head_dim_);                         \
    }                                                               \
  }()

// REUSEKV_SWITCH(num_blocks, F)
//   Chooses a compile-time `constexpr static int REUSE_KV_TIMES` and returns F():
//     4 : num_heads is even, num_heads / num_kv_heads >= 4, num_blocks >= 1200
//     2 : num_heads / num_kv_heads >= 2, num_blocks >= 1200
//     1 : otherwise
//   NOTE: besides its argument, the expansion reads `num_heads` and
//   `num_kv_heads` from the CALLER's scope through the lambda's [&] capture;
//   both must be visible at the point of use, and num_kv_heads must be
//   non-zero because it is used as a divisor.
//   num_blocks is parenthesized so compound argument expressions are compared
//   against the threshold as a whole.
//   NOTE(review): the 4x branch tests `num_heads % 2`, not whether the ratio
//   num_heads / num_kv_heads is a multiple of 4 — confirm this is intended.
#define REUSEKV_SWITCH(num_blocks, ...)                                       \
  [&] {                                                                       \
    /* Smallest num_blocks for which REUSE_KV_TIMES can exceed 1. */          \
    constexpr int reusekv_min_blocks_ = 1200;                                 \
    if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 &&                \
        (num_blocks) >= reusekv_min_blocks_) {                                \
      constexpr static int REUSE_KV_TIMES = 4;                                \
      return __VA_ARGS__();                                                   \
    } else if (num_heads / num_kv_heads >= 2 &&                               \
               (num_blocks) >= reusekv_min_blocks_) {                         \
      constexpr static int REUSE_KV_TIMES = 2;                                \
      return __VA_ARGS__();                                                   \
    } else {                                                                  \
      constexpr static int REUSE_KV_TIMES = 1;                                \
      return __VA_ARGS__();                                                   \
    }                                                                         \
  }()

// REUSEKV_SWITCH_V1(num_blocks, F)
//   Like REUSEKV_SWITCH but chooses only between 2 and 1: sets
//   `constexpr static int REUSE_KV_TIMES` to 2 when num_heads > num_kv_heads
//   and num_blocks >= 1200, otherwise to 1, then returns F().
//   NOTE: reads `num_heads` and `num_kv_heads` from the CALLER's scope through
//   the lambda's [&] capture; both must be visible at the point of use.
//   num_blocks is parenthesized so compound argument expressions are compared
//   against the threshold as a whole.
#define REUSEKV_SWITCH_V1(num_blocks, ...)                                    \
  [&] {                                                                       \
    /* Smallest num_blocks for which REUSE_KV_TIMES can exceed 1. */          \
    constexpr int reusekv_min_blocks_ = 1200;                                 \
    if (num_heads > num_kv_heads && (num_blocks) >= reusekv_min_blocks_) {    \
      constexpr static int REUSE_KV_TIMES = 2;                                \
      return __VA_ARGS__();                                                   \
    } else {                                                                  \
      constexpr static int REUSE_KV_TIMES = 1;                                \
      return __VA_ARGS__();                                                   \
    }                                                                         \
  }()