Commit 30a1206c authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Refactor `load_B_col_major` to follow scale mapping

parent 1bcc08ff
...@@ -271,11 +271,7 @@ struct load_A_row_major ...@@ -271,11 +271,7 @@ struct load_A_row_major
// Define a load function for scaled A blocks: // Define a load function for scaled A blocks:
// Size: (BLOCK_M x BLOCK_K) // Size: (BLOCK_M x BLOCK_K)
// ASSUMPTION: // ASSUMPTION:
// - We want contiguous BLOCK_M sized column neighbors in register.
// - Data is in row major format
// - The scale inputs distributed across 64 lanes. // - The scale inputs distributed across 64 lanes.
// This means:
// - From A we will load BLOCK_M rows of size K to satisfy our input data
template <typename AType, template <typename AType,
typename AFragT, typename AFragT,
typename ScaleType, typename ScaleType,
...@@ -349,80 +345,99 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, ...@@ -349,80 +345,99 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr,
// Define a load function for input B blocks: // Define a load function for input B blocks:
// Size: (BLOCK_K x BLOCK_N) // Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION: // - Data is in col major format
// - We want contiguous BLOCK_N sized row neighbors in register. // - Cols are loaded in contiguous chunks that map to corresponding microscales
// - Data is in column major format // - Each col is loaded in chunks of size 16 and each thread loads 32 elements
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N> template <typename BType, typename BFragT, int32_t BLOCK_K, int32_t BLOCK_N>
__device__ BFragT load_B_col_major(BType const* input_ptr) __device__ BFragT load_B_col_major(BType const* input_ptr)
{ {
// clang-format off // clang-format off
// Register Mapping for 128x16: || Register Mapping for 64x32: // Register Mapping for 128x16: || Register Mapping for 64x32:
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 | // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element|
// Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------|
// Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] |
// Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] |
// Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] |
// Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] |
// Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] |
// Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] |
// Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] |
// Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] |
// Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] |
// Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] |
// Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] |
// Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11] // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] |
// Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] |
// Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] |
// Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] |
// Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] |
// Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] |
// Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] |
// Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] |
// Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] |
// Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] |
// Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21] // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] |
// Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] |
// Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] |
// Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] |
// Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] |
// Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] |
// Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] |
// Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] |
// Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] |
// Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] |
// Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] |
// clang-format on // clang-format on
// Here we want to load a BLOCK_K x BLOCK_N block of data. static constexpr int32_t WAVE_SIZE = 64;
static constexpr uint32_t VW = vectorSize(BFragT{});
// Here we want to load from cols of B in chunks of 16 elements each.
static constexpr uint32_t chunk_size = 16;
// each chunk is separated by an offset
static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64
// To start the loading process, let's visualize in 2D coords. // To start the loading process, let's visualize in 2D coords.
// Each thread will load 32 elements. // Each thread will load 32 elements.
// We need to know where they start, and where the next elements are. // We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row auto startCoord2D =
threadIdx.x % BLOCK_N); // Col std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48}
threadIdx.x % BLOCK_N); // Col {0-31} | {0-15}
// Flatten to 1D col_major offsets. // Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
// auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols
auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col
// BLOCK_K is a stride in B matrix
auto startOffset = col_major(startCoord2D, BLOCK_K); auto startOffset = col_major(startCoord2D, BLOCK_K);
// auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K);
auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K);
auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset); using BRawT = typename scalar_type<BFragT>::type;
return *fragPtr; using BScalarFragT = vector_type<BRawT, chunk_size>::type;
union
{
BFragT frag;
BScalarFragT chunks[2];
} fragB{};
auto* fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset);
fragB.chunks[0] = *fragPtr;
fragPtr = reinterpret_cast<BScalarFragT const*>(input_ptr + startOffset + kMajorOffset);
fragB.chunks[1] = *fragPtr;
return fragB.frag;
} }
// Define a load function for scaled B blocks: // Define a load function for scaled B blocks:
// Size: (BLOCK_K x BLOCK_N) // Size: (BLOCK_K x BLOCK_N)
// ASSUMPTION: // ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in column major format
// - The scale inputs distributed across 64 lanes. // - The scale inputs distributed across 64 lanes.
// This means:
// - From B we will load K rows of size BLOCK_N to satisfy our input data
template <typename BType, template <typename BType,
typename BFragT, typename BFragT,
typename ScaleType, typename ScaleType,
...@@ -435,6 +450,46 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, ...@@ -435,6 +450,46 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr,
ScaleFragT& fragX) ScaleFragT& fragX)
{ {
// clang-format off
// Register Mapping for 128x16: || Register Mapping for 64x32:
// Size | BLOCK_N | BLOCK_N | | BLOCK_N | BLOCK_N | | || Size | BLOCK_N | BLOCK_N | | |
// N | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || N | 0 ... 31 | 0 ... 31 | Vector | |
// Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| Scale |
// Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------|
// Reg 0 [0:7] | K0 | K16 | x(0,N) | K32 | K48 | x(1,N) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(0,N) |
// Reg 0 [8:15] | K1 | K17 | x(0,N) | K33 | K49 | x(1,N) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(0,N) |
// Reg 0 [16:23] | K2 | K18 | x(0,N) | K34 | K50 | x(1,N) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(0,N) |
// Reg 0 [24:31] | K3 | K19 | x(0,N) | K35 | K51 | x(1,N) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(0,N) |
// Reg 1 [0:7] | K4 | K20 | x(0,N) | K36 | K52 | x(1,N) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(0,N) |
// Reg 1 [8:15] | K5 | K21 | x(0,N) | K37 | K53 | x(1,N) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(0,N) |
// Reg 1 [16:23] | K6 | K22 | x(0,N) | K38 | K54 | x(1,N) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(0,N) |
// Reg 1 [24:31] | K7 | K23 | x(0,N) | K39 | K55 | x(1,N) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(0,N) |
// Reg 2 [0:7] | K8 | K24 | x(0,N) | K40 | K56 | x(1,N) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(0,N) |
// Reg 2 [8:15] | K9 | K25 | x(0,N) | K41 | K57 | x(1,N) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(0,N) |
// Reg 2 [16:23] | K10 | K26 | x(0,N) | K42 | K58 | x(1,N) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(0,N) |
// Reg 2 [24:31] | K11 | K27 | x(0,N) | K43 | K59 | x(1,N) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(0,N) |
// Reg 3 [0:7] | K12 | K28 | x(0,N) | K44 | K60 | x(1,N) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(0,N) |
// Reg 3 [8:15] | K13 | K29 | x(0,N) | K45 | K61 | x(1,N) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(0,N) |
// Reg 3 [16:23] | K14 | K30 | x(0,N) | K46 | K62 | x(1,N) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(0,N) |
// Reg 3 [24:31] | K15 | K31 | x(0,N) | K47 | K63 | x(1,N) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | x(0,N) |
// Reg 4 [0:7] | K64 | K80 | x(2,N) | K96 | K112 | x(3,N) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(1,N) |
// Reg 4 [8:15] | K65 | K81 | x(2,N) | K97 | K113 | x(3,N) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(1,N) |
// Reg 4 [16:23] | K66 | K82 | x(2,N) | K98 | K114 | x(3,N) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(1,N) |
// Reg 4 [24:31] | K67 | K83 | x(2,N) | K99 | K115 | x(3,N) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(1,N) |
// Reg 5 [0:7] | K68 | K84 | x(2,N) | K100 | K116 | x(3,N) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(1,N) |
// Reg 5 [8:15] | K69 | K85 | x(2,N) | K101 | K117 | x(3,N) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(1,N) |
// Reg 5 [16:23] | K70 | K86 | x(2,N) | K102 | K118 | x(3,N) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(1,N) |
// Reg 5 [24:31] | K71 | K87 | x(2,N) | K103 | K119 | x(3,N) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(1,N) |
// Reg 6 [0:7] | K72 | K88 | x(2,N) | K104 | K120 | x(3,N) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(1,N) |
// Reg 6 [8:15] | K73 | K89 | x(2,N) | K105 | K121 | x(3,N) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(1,N) |
// Reg 6 [16:23] | K74 | K90 | x(2,N) | K106 | K122 | x(3,N) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(1,N) |
// Reg 6 [24:31] | K75 | K91 | x(2,N) | K107 | K123 | x(3,N) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(1,N) |
// Reg 7 [0:7] | K76 | K92 | x(2,N) | K108 | K124 | x(3,N) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(1,N) |
// Reg 7 [8:15] | K77 | K93 | x(2,N) | K109 | K125 | x(3,N) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(1,N) |
// Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) |
// Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) |
// clang-format on
static constexpr uint32_t VW = vectorSize(BFragT{}); static constexpr uint32_t VW = vectorSize(BFragT{});
static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment