Commit 2bd601e1 authored by Andriy Roshchenko

Cleanup

parent 718c7abb
@@ -30,14 +30,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
 {
     __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
     {
-#if 1
         auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
         op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
-#else
-        ignore = fragA;
-        ignore = fragB;
-        ignore = fragAcc;
-#endif
     }
 };
@@ -46,14 +40,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
 {
     __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
     {
-#if 1
         auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
         op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
-#else
-        ignore = fragA;
-        ignore = fragB;
-        ignore = fragAcc;
-#endif
     }
 };
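For context on the two hunks above: mfma_type_selector dispatches on the output tile shape at compile time, so a kernel picks the right MFMA builtin simply by instantiating the selector with its block dimensions. A minimal sketch of a call site, hedged — the fragment type names come from this file, but the invocation itself is illustrative and not part of this commit:

    // Hypothetical call site inside a GEMM kernel's K-loop:
    // the 16x16 specialization resolves to mfma_f32_16x16x128f8f6f4,
    // the 32x32 one to mfma_f32_32x32x64f8f6f4.
    mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>{}(fragA, fragB, fragAcc);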
@@ -131,43 +119,9 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
     // BLOCK_M is a stride in A matrix
     auto startOffset = col_major(startCoord2D, BLOCK_M);
     auto kOffset     = col_major(stepCoord2D, BLOCK_M);
     // kOffset == BLOCK_M
     // This means every BLOCK_M element is loaded into output vector
-#if 0
-    auto fragA = AScalarFragT{
-        bit_cast<ARawT>(input_ptr[startOffset]),                // XXX v[0]  = Reg 0 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]),  // XXX v[1]  = Reg 0 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]),  // XXX v[2]  = Reg 0 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]),  // XXX v[3]  = Reg 0 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]),  // XXX v[4]  = Reg 1 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]),  // XXX v[5]  = Reg 1 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]),  // XXX v[6]  = Reg 1 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]),  // XXX v[7]  = Reg 1 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]),  // XXX v[8]  = Reg 2 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]),  // XXX v[9]  = Reg 2 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]), // XXX v[10] = Reg 2 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]), // XXX v[11] = Reg 2 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]), // XXX v[12] = Reg 3 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]), // XXX v[13] = Reg 3 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]), // XXX v[14] = Reg 3 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]), // XXX v[15] = Reg 3 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]), // XXX v[16] = Reg 4 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]), // XXX v[17] = Reg 4 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]), // XXX v[18] = Reg 4 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]), // XXX v[19] = Reg 4 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]), // XXX v[20] = Reg 5 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]), // XXX v[21] = Reg 5 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]), // XXX v[22] = Reg 5 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]), // XXX v[23] = Reg 5 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]), // XXX v[24] = Reg 6 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]), // XXX v[25] = Reg 6 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]), // XXX v[26] = Reg 6 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]), // XXX v[27] = Reg 6 [24:31]
-        bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]), // XXX v[28] = Reg 7 [0:7]
-        bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]), // XXX v[29] = Reg 7 [8:15]
-        bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]), // XXX v[30] = Reg 7 [16:23]
-        bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])}; // XXX v[31] = Reg 7 [24:31]
-#else
     auto fragA = AScalarFragT{};
 #pragma unroll VW
     for(uint32_t i = 0; i < VW; i++)
@@ -175,7 +129,6 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
     {
         fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
     }
-#endif
     return fragA;
 }
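The deleted 32-entry initializer and the surviving loop compute the same gather: with col-major A and leading dimension BLOCK_M, one step along K advances the flat offset by BLOCK_M. A quick host-side check of that arithmetic, using the same col_major lambda as the file (the standalone harness is mine, not part of the commit):

    #include <cstdio>
    #include <utility>

    int main()
    {
        constexpr unsigned BLOCK_M = 16; // leading dimension of col-major A
        auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
        // One step along K is (row += 0, col += 1):
        auto kOffset = col_major(std::make_pair(0u, 1u), BLOCK_M);
        std::printf("kOffset = %u\n", kOffset); // prints 16 == BLOCK_M, so the
        // loop gathers every BLOCK_M-th scalar, exactly what the deleted
        // hand-unrolled initializer spelled out register by register.
        return 0;
    }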
@@ -237,15 +190,12 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
     // We need to know where they start, and where the next elements are.
     auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
                                        threadIdx.x % BLOCK_N);      // Col
-    // auto stepCoord2D = std::make_pair(1u, 0u);

     // Flatten to 1D col_major offsets.
     auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
     auto startOffset = col_major(startCoord2D, BLOCK_K);
-    // auto kOffset = col_major(stepCoord2D, BLOCK_K);
-    // kOffset == 1

     auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
     return *fragPtr;
 }
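load_B_col_major keeps only the contiguous fast path: the deleted comments note kOffset == 1, i.e. a lane's VW elements sit back to back in col-major B with leading dimension BLOCK_K, so one vector load through a reinterpret_cast replaces VW scalar loads. A hedged standalone illustration of that idiom (the 16-byte fragment type is an assumption for the example, not the file's BFragT, and it presumes suitably aligned data):

    #include <cstdint>
    #include <hip/hip_runtime.h>

    // Example fragment: 16 packed 8-bit values loaded in one shot.
    struct Frag16B { uint8_t v[16]; };

    __device__ Frag16B load_contiguous(uint8_t const* ptr, uint32_t startOffset)
    {
        // Stride-1 elements allow one wide load instead of 16 scalar ones
        // (assumes ptr + startOffset is 16-byte aligned).
        return *reinterpret_cast<Frag16B const*>(ptr + startOffset);
    }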
@@ -278,29 +228,16 @@ struct store_C_col_major<CType, CFragT, 16, 16>
     static constexpr uint32_t VW  = vectorSize(cFrag); // 4
     static constexpr uint32_t Dim = 16;
-#if 1
-    for(int i = 0; i < vectorSize(cFrag); ++i)
-    {
-        printf("threadIdx.x = %d; cFrag[%d] = %f\n",
-               static_cast<int>(threadIdx.x),
-               i,
-               static_cast<float>(cFrag[i]));
-    }
-#endif
     // Each thread will load 4 elements.
     // We need to know where they start, and where the next elements are.
     auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                        threadIdx.x % Dim);      // Col
-    // auto stepCoord2D = std::make_pair(1u, 0u);

     // Flatten to 1D col_major offsets.
     auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
     auto startOffset = col_major(startCoord2D, 16);
-    // auto kOffset = col_major(stepCoord2D, 16); // 1
-    // kOffset == 1

     auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
     *fragPtr = cFrag;
 }
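The 16x16 store's coordinate math mirrors the loads: lane t owns a 4-element column segment starting at row (t / 16) * 4, column t % 16, flattened col-major with leading dimension 16. A quick worked check (the host-side harness is mine):

    #include <cstdio>

    int main()
    {
        // Lane 17: row = (17/16)*4 = 4, col = 17%16 = 1, so offsets 20..23.
        for(unsigned t : {0u, 17u, 63u})
        {
            unsigned row = (t / 16) * 4;
            unsigned col = t % 16;
            std::printf("lane %2u -> offsets %u..%u\n", t, row + col * 16, row + col * 16 + 3);
        }
        return 0;
    }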
@@ -343,22 +280,9 @@ struct store_C_col_major<CType, CFragT, 32, 32>
     static constexpr uint32_t Dim = 32;
     static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
-#if 1
-    for(int i = 0; i < vectorSize(cFrag); ++i)
-    {
-        printf("threadIdx.x = %d; cFrag[%d] = %f\n",
-               static_cast<int>(threadIdx.x),
-               i,
-               static_cast<float>(cFrag[i]));
-    }
-#endif
     auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
                                        threadIdx.x % Dim);      // Col
-    // Minor step for each 'chunk'
-    // auto minorStepCoord2D = std::make_pair(1u, 0u);
     // Major step between 'chunks'
     auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
@@ -366,11 +290,9 @@ struct store_C_col_major<CType, CFragT, 32, 32>
     auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
     auto startOffset = col_major(startCoord2D, 32);
-    // auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
     auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
-    // kMinorOffset == 1.
-    // This means we can vector store 4 contiguous elements at a time.
+    // we can vector store 4 contiguous elements at a time.
     using CRawT        = typename scalar_type<CFragT>::type;
     using CScalarFragT = vector_type<CRawT, VW>::type;
     union
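The hunk display cuts off at the union that the 32x32 path uses to view the accumulator fragment as VW-wide raw vectors, stored one contiguous chunk per kMajorOffset step. A plausible self-contained sketch of that pattern, assuming float accumulators, VW == 4, and four chunks; the names and shapes here are my reconstruction for illustration, not code from this commit:

    #include <cstdint>
    #include <hip/hip_runtime.h>

    using float4_t  = float __attribute__((ext_vector_type(4)));
    using float16_t = float __attribute__((ext_vector_type(16)));

    __device__ void store_chunks(float* output, float16_t cFrag,
                                 uint32_t startOffset, uint32_t kMajorOffset)
    {
        // Reinterpret the 16 accumulators as four 4-wide vectors and store
        // each chunk kMajorOffset (8) elements apart in the col-major tile.
        union { float16_t frag; float4_t chunks[4]; } u{cFrag};
    #pragma unroll
        for(uint32_t i = 0; i < 4; ++i)
        {
            *reinterpret_cast<float4_t*>(output + startOffset + i * kMajorOffset) = u.chunks[i];
        }
    }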
@@ -444,16 +366,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
 */
 struct GemmParams
 {
-    /**
-     * @brief This constructor initializes the parameters for GEMM storage with default values.
-     *
-     * A[16x128] * B[128x16] = C[16x16], all row major.
-     */
-    GemmParams() : M(16), N(16), K(128) {}
-    ck::index_t M;
-    ck::index_t N;
-    ck::index_t K;
+    ck::index_t M = 16;
+    ck::index_t N = 16;
+    ck::index_t K = 128;
     ck::index_t StrideA = -1;
     ck::index_t StrideB = -1;
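The last hunk trades the documented constructor for in-class default member initializers; default construction behaves identically, and the deleted doc comment's problem shape (A[16x128] * B[128x16] = C[16x16], all row major) is preserved by the new defaults. A minimal equivalence check, with ck::index_t stubbed as int since the ck headers aren't in scope here:

    #include <cassert>

    struct GemmParamsOld
    {
        GemmParamsOld() : M(16), N(16), K(128) {}
        int M, N, K;
    };
    struct GemmParamsNew
    {
        int M = 16;
        int N = 16;
        int K = 128;
    };

    int main()
    {
        GemmParamsOld a;
        GemmParamsNew b;
        assert(a.M == b.M && a.N == b.N && a.K == b.K); // same defaults either way
        return 0;
    }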