"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "d373a48c9875f1bb43fde05215a42179b453f81c"
Commit 9f8e26f6 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Add row-major C store

parent 6c39e6af
...@@ -487,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32> ...@@ -487,6 +487,129 @@ struct store_C_col_major<CType, CFragT, 32, 32>
} }
}; };
// Define a store function for C
// Size: (BLOCK_M x BLOCK_N)
// ASSUMPTION:
// - We want contiguous BLOCK_N sized row neighbors in register.
// - Data is in row major format
//
// This is only the primary template declaration; it has no body here.
// The actual store logic lives in the partial specializations that fix
// (BLOCK_M, BLOCK_N), e.g. the 16x16 and 32x32 ones in this file.
// NOTE(review): unless a definition exists elsewhere, instantiating any other
// block size leaves an incomplete type and fails at compile time - confirm
// that this is the intended behavior.
//
// Template parameters:
//   CType   - element type of the destination buffer.
//   CFragT  - per-thread fragment type holding this thread's C values.
//   BLOCK_M - number of rows of the C block.
//   BLOCK_N - number of columns of the C block.
template <typename CType, typename CFragT, int32_t BLOCK_M, int32_t BLOCK_N>
struct store_C_row_major;
// Here we want to store a 16x16 block of data.
//
// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N |
// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 |
// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector
// Register Element ------------ ------------- ------------ -------------- Element
// Reg0 | M0 | M4 | M8 | M12 | v[0]
// Reg1 | M1 | M5 | M9 | M13 | v[1]
// Reg2 | M2 | M6 | M10 | M14 | v[2]
// Reg3 | M3 | M7 | M11 | M15 | v[3]
// Row-major store of a 16x16 C block held in per-thread fragments.
//
// Parameters:
//   output - pointer to the first element of the 16x16 row-major destination
//            block (leading dimension 16).
//   cFrag  - this thread's 4 accumulator values; element v[i] belongs to row
//            (threadIdx.x / 16) * 4 + i and column threadIdx.x % 16.
//
// The indexing uses threadIdx.x directly.
// NOTE(review): this presumably expects a single 64-lane wavefront with
// threadIdx.x in [0, 63]; larger indices would address past the 16x16 block -
// confirm against the launch configuration.
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 16, 16>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t VW  = vectorSize(cFrag); // 4
        static constexpr uint32_t Dim = 16; // BLOCK_N, also the leading dimension

        // The stores below address exactly v[0]..v[3].
        static_assert(VW == 4, "store_C_row_major<16, 16> expects 4 elements per thread");

        // Each thread will store 4 elements.
        // We need to know where they start, and where the next elements are.
        auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                           threadIdx.x % Dim); // Col
        auto stepCoord2D  = std::make_pair(1u, 0u);

        // Flatten to 1D row_major offsets.
        auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; };

        auto startOffset = row_major(startCoord2D, Dim);
        auto kOffset     = row_major(stepCoord2D, Dim);

        // kOffset != 1: the 4 registers of a thread land in 4 consecutive rows
        // of the same column, i.e. they are NOT contiguous in memory.
        // A single vector store through a CFragT* (as done for the col-major
        // layout) must not be used here: it would write v[0..3] into 4 adjacent
        // columns of one row - elements owned by neighbouring threads - at an
        // address that is generally not vector-aligned.
        // The elements are therefore written one by one, typically compiling to
        // 4 separate global_store_dword instructions.
        output[startOffset]               = cFrag[0]; // v[0] = Reg 0
        output[startOffset + kOffset]     = cFrag[1]; // v[1] = Reg 1
        output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2
        output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3
    }
};
// Here we want to store a 32x32 block of data.
// Register Mapping:
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
// Row-major store of a 32x32 C block held in per-thread fragments.
//
// Parameters:
//   output - pointer to the first element of the 32x32 row-major destination
//            block (leading dimension 32).
//   cFrag  - this thread's 16 accumulator values, organised as 4 chunks of
//            4 registers. Register r of chunk c belongs to row
//            (threadIdx.x / 32) * 4 + c * 8 + r and column threadIdx.x % 32.
//
// The indexing uses threadIdx.x directly.
// NOTE(review): this presumably expects a single 64-lane wavefront with
// threadIdx.x in [0, 63] - confirm against the launch configuration.
template <typename CType, typename CFragT>
struct store_C_row_major<CType, CFragT, 32, 32>
{
    __device__ void operator()(CType* output, CFragT cFrag)
    {
        static constexpr uint32_t kWaveSize      = 64;
        static constexpr uint32_t kLeadingDim    = 32; // BLOCK_N
        static constexpr int      kRegsPerChunk  = 4; // vector width of one 'chunk'
        static constexpr int      kNumChunks     = 4; // 16 registers / 4 per chunk
        static constexpr uint32_t kRowsPerChunk  = kRegsPerChunk * kWaveSize / 32; // 8

        // Distance (in elements) between consecutive registers of a chunk:
        // one row down, same column.
        static constexpr uint32_t kRowStride = kLeadingDim;
        // Distance (in elements) between the first registers of two chunks.
        static constexpr uint32_t kChunkStride = kRowsPerChunk * kLeadingDim;

        // Element written by register 0 of chunk 0.
        const uint32_t firstRow = (threadIdx.x / kLeadingDim) * kRegsPerChunk;
        const uint32_t column   = threadIdx.x % kLeadingDim;
        const uint32_t base     = firstRow * kLeadingDim + column;

        // Walk the registers in order v[0] .. v[15]; both trip counts are
        // compile-time constants, so the loops can be fully unrolled.
#pragma unroll
        for(int chunk = 0; chunk < kNumChunks; ++chunk)
        {
            const uint32_t chunkBase = base + static_cast<uint32_t>(chunk) * kChunkStride;
#pragma unroll
            for(int reg = 0; reg < kRegsPerChunk; ++reg)
            {
                output[chunkBase + static_cast<uint32_t>(reg) * kRowStride]
                    = cFrag[chunk * kRegsPerChunk + reg];
            }
        }
    }
};
template <typename AType, template <typename AType,
typename BType, typename BType,
typename CType, typename CType,
...@@ -581,7 +704,7 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, ...@@ -581,7 +704,7 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]); fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
} }
auto storeC = store_C_col_major<CType, CFragT, BLOCK_M, BLOCK_N>{}; auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC); storeC(c, fragC);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment