// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
*
*
* wscales store order: (pack = 4)
* 0 1 8 9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
*
* 16 17 24 25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 22 23 30 31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
 *
 * ...
 *
* 112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
*
* wscales store order: (pack = 8)
* 0 1 8 9 16 17 24 25 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
* 2 3 10 11 18 19 26 27 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
* 4 5 12 13 20 21 28 29 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
* 6 7 14 15 22 23 30 31 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
*
* 224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
* ...
* 230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
*
* {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
*
*
*
 * max pack size is 8 since the max load size is 16 bytes per lane
 * min pack size is 2 since the shuffle granularity is 32 bits (2 * half)
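 *
 * A minimal sketch of the lookup implied by the wscale mapping above. The function name and the
 * packed-register argument are hypothetical (not the actual kernel code); it assumes WSCALES_PACK_SIZE
 * and WARP_SIZE are compile-time constants, that k stays within the currently loaded pack, and that
 * cuda_fp16.h is included for half2 / __low2half / __high2half:
 *
 *   __device__ __forceinline__ half get_wscale(const half2 (&pack)[WSCALES_PACK_SIZE / 2],
 *                                              int k, int laneId) {
 *       // lane that owns the k-th wscale used by this lane, per the formula above
 *       int srcLane = 4 * (k / WSCALES_PACK_SIZE) + laneId % 4;
 *       int elem    = k % WSCALES_PACK_SIZE;              // element inside that lane's pack
 *       // shuffle granularity is 32 bits, so broadcast one half2 and pick a half out of it
 *       half2 pair  = __shfl_sync(0xffffffff, pack[elem / 2], srcLane);
 *       return (elem % 2 == 0) ? __low2half(pair) : __high2half(pair);
 *   }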
 *
 * ascales store order: (pack = 2)
 *  0  8 <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
 * ...
 * 55 63 <-- load by lane 31, broadcast to lane {28, 29, 30, 31} (4x)
*
 * {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
*
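 *
 * The ascale lookup is analogous (again a hypothetical sketch, not the actual member code):
 * only the owner-lane computation changes, the 32-bit half2 shuffle stays the same.
 *
 *   int srcLane = 8 * (k / ASCALES_PACK_SIZE) + laneId / 4;  // lane that owns the value
 *   int elem    = k % ASCALES_PACK_SIZE;                     // element inside that lane's pack
 *   half2 pair  = __shfl_sync(0xffffffff, pack[elem / 2], srcLane);
 *   half scale  = (elem % 2 == 0) ? __low2half(pair) : __high2half(pair);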