Commit 938256dd authored by Andriy Roshchenko
Browse files

Refactor `load_A_row_major` to follow scale mapping

parent b138d4fd
......@@ -166,67 +166,197 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// - Data is in row major format
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
// template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
// struct load_A_row_major
// {
// __device__ AFragT operator()(AType const* input_ptr)
// {
// // clang-format off
// // Register Mapping for 16x128: || Register Mapping for 32x64:
// // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// // clang-format on
// // Here we want to load a BLOCK_M x BLOCK_K block of data.
// static constexpr uint32_t VW = vectorSize(AFragT{});
// // To start the loading process, let's visualize in 2D coords.
// // Each thread will load 32 elements.
// // We need to know where they start, and where the next elements are.
// auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
// (threadIdx.x / BLOCK_M) * VW); // Col
// // Flatten to 1D row_major offsets.
// auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second;
// };
// // BLOCK_K is a stride in A matrix
// auto startOffset = row_major(startCoord2D, BLOCK_K);
// auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
// return *fragPtr;
// }
// }
template <typename AType, typename AFragT, int32_t BLOCK_M, int32_t BLOCK_K>
__device__ AFragT load_A_row_major(AType const* input_ptr)
struct load_A_row_major
{
// clang-format off
__device__ AFragT operator()(AType const* input_ptr)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0]
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1]
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2]
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3]
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4]
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5]
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6]
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7]
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8]
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9]
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10]
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11]
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12]
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13]
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14]
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15]
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16]
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17]
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18]
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19]
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20]
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21]
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22]
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23]
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24]
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25]
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26]
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27]
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28]
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29]
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30]
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31]
// clang-format on
// Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | |
// M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
// Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
// Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
// Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
// Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
// Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
// Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
// Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
// Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
// Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
// Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
// Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
// Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
// Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
// Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
// Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
// Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
// Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
// Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
// Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
// Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
// Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
// Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
// Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
// Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
// Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
// Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
// Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
// Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
// Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
// Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
// Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// clang-format on
static constexpr int32_t WAVE_SIZE = 64;
// Here we want to load from rows of A in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
// the two chunks a thread reads from its row start chunk_offset elements apart along K
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M;
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D =
std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15}
(threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48}
// Here we want to load a BLOCK_M x BLOCK_K block of data.
static constexpr uint32_t VW = vectorSize(AFragT{});
// auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows
auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row
// To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row
(threadIdx.x / BLOCK_M) * VW); // Col
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// Flatten to 1D row_major offsets.
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K);
// auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K);
auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K);
// BLOCK_K is a stride in A matrix
auto startOffset = row_major(startCoord2D, BLOCK_K);
using ARawT = typename scalar_type<AFragT>::type;
using AScalarFragT = vector_type<ARawT, chunk_size>::type;
auto const* fragPtr = reinterpret_cast<AFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
union
{
AFragT frag;
AScalarFragT chunks[2];
} fragA{};
auto* fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset);
fragA.chunks[0] = *fragPtr;
fragPtr = reinterpret_cast<AScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
fragA.chunks[1] = *fragPtr;
return fragA.frag;
}
};
// Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K)
......@@ -246,8 +376,46 @@ template <typename AType,
__device__ AFragT load_mx_A_row_major(AType const* input_ptr,
ScaleType const* scale_ptr,
ScaleFragT& fragX)
{
// clang-format off
// Register Mapping for 16x128: || Register Mapping for 32x64:
// Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | |
// M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | |
// Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale |
// Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------|
// Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) |
// Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) |
// Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) |
// Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) |
// Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) |
// Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) |
// Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) |
// Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) |
// Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) |
// Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) |
// Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) |
// Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) |
// Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) |
// Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) |
// Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) |
// Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(M,0) |
// Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) |
// Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) |
// Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) |
// Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) |
// Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) |
// Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) |
// Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) |
// Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) |
// Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) |
// Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) |
// Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) |
// Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) |
// Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) |
// Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) |
// Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) |
// Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) |
// clang-format on
static constexpr uint32_t VW = vectorSize(AFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
......@@ -266,7 +434,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// preserve the upper 24 bits of fragX and store the 8-bit exponent in the low byte
fragX = (fragX & 0xFFFFFF00) | (utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF);
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>(input_ptr);
return load_A_row_major<AType, AFragT, BLOCK_M, BLOCK_K>{}(input_ptr);
}
// Define a load function for input B blocks:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment