#define CHECK_NAN(data, name) checkNan(data, name " at " STRINGIZE(__LINE__))
#else
#define CHECK_NAN(data, name)
#endif
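// CHECK_NAN expands to a checkNan() call tagged with the call site's source line
// when NaN checking is enabled, and to nothing otherwise, so instrumented builds
// can localize where a NaN first appears with no overhead in release builds.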
class GEMMConfig_W4A4 {
public:
    // BE CAREFUL: weights need to be repacked when the tiling size changes
    static constexpr int BLOCK_M = 256;
    static constexpr int BLOCK_N = 128;
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    static constexpr int INSN_M = 16;
    static constexpr int INSN_N = 16;
    static constexpr int INSN_K = 64;

#if 0
    using half_t  = half;
    using half2_t = half2;
#else
    using half_t  = __nv_bfloat16;
    using half2_t = __nv_bfloat162;
#endif
};
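// Illustrative sanity checks (not in the original source): the W4A4 tiling only
// works out if the block splits evenly across warps and mma tiles, which is why
// the comment above warns that weights must be repacked when these constants change.
static_assert(GEMMConfig_W4A4::BLOCK_M % GEMMConfig_W4A4::NUM_WARPS == 0,
              "each warp must own an equal slice of BLOCK_M");
static_assert((GEMMConfig_W4A4::BLOCK_M / GEMMConfig_W4A4::NUM_WARPS) % GEMMConfig_W4A4::INSN_M == 0,
              "per-warp M must be a whole number of mma tiles");
static_assert(GEMMConfig_W4A4::BLOCK_N % GEMMConfig_W4A4::INSN_N == 0,
              "BLOCK_N must be a whole number of mma tiles");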
class GEMMConfig_W8A8 {
public:
    static constexpr int BLOCK_M = 256;
    static constexpr int BLOCK_N = 128;
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    static constexpr int INSN_M = 16;
    static constexpr int INSN_N = 16;
    static constexpr int INSN_K = 32;

    using half_t  = half;
    using half2_t = half2;
};
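// Note: INSN_K follows the k extent of the underlying tensor-core mma shape,
// which depends on operand width: m16n8k64 for the int4 path (see the PTX link
// in the comment below) versus m16n8k32 for the int8 path.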
template<class Config>
class GEMMBase : public Config {
public:
    using Config::BLOCK_M;
    using Config::BLOCK_N;
    using Config::WARP_SIZE;
    using Config::NUM_WARPS;
    using Config::INSN_M;
    using Config::INSN_N;
    using Config::INSN_K;
    using typename Config::half_t;
    using typename Config::half2_t;

    static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
    static constexpr int WARP_N = BLOCK_N;
    static constexpr int WARP_K = INSN_K;

    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES = WARP_K / INSN_K;
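    // Worked numbers (illustrative): with GEMMConfig_W4A4 this gives
    // WARP_M = 256 / 8 = 32, so WARP_M_TILES = 32 / 16 = 2, WARP_N_TILES =
    // 128 / 16 = 8, and WARP_K_TILES = 64 / 64 = 1; each warp covers
    // 2 x 8 = 16 INSN_M x INSN_N tiles per k-step.
    static_assert(WARP_M % INSN_M == 0 && WARP_N % INSN_N == 0 && WARP_K % INSN_K == 0,
                  "warp tile must decompose into whole mma tiles");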
/**
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
*
* wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
* ... ...
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
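*
* e.g. (pack = 4): the k = 5 wscale used by lane i = 6 sits in pack 5 // (4*32) = 0,
* lane 4*(5 // 4) + 6 % 4 = 6, element 5 % 4 = 1; lane 6's pack is {20, 21, 28, 29},
* so element 1 is scale 21, matching lane 6's usage order {4, 5, 12, 13, 20, 21, ...}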
*
* max pack size set to 8 since the max load width is 16 bytes per lane
* min pack size set to 2 since the shuffle granularity is 32 bits (2 x half)
*
* ascales store order: (pack = 2)
* 0 8 <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
* 1 9 <-- load by lane 1, broadcast to lane {4, 5, 6, 7} (4x)
* 2 10
* ...
* 6 14
* 7 15 <-- load by lane 7, broadcast to lane {28, 29, 30, 31} (4x)
* ... ...
* 48 56 <-- load by lane 24, broadcast to lane {0, 1, 2, 3} (4x)
* 49 57
* ...
* 54 62
* 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
*
* {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
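*
* e.g. (pack = 2): the k = 3 ascale used by lane i = 13 sits in pack 3 // (2*32) = 0,
* lane 8*(3 // 2) + 13 // 4 = 11, element 3 % 2 = 1; lane 11's pack is {19, 27},
* so element 1 is scale 27, matching lane 13's usage order {3, 11, 19, 27, ...}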