Commit 2bd601e1 authored by Andriy Roshchenko

Cleanup

parent 718c7abb
@@ -30,14 +30,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 16, 16>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
#if 1
auto op = mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>{};
op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
ignore = fragA;
ignore = fragB;
ignore = fragAcc;
#endif
}
};
@@ -46,14 +40,8 @@ struct mfma_type_selector<AFragT, BFragT, AccumFragT, 32, 32>
{
__device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc)
{
#if 1
auto op = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>{};
op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc);
#else
ignore = fragA;
ignore = fragB;
ignore = fragAcc;
#endif
}
};
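Both specializations above follow the same dispatch pattern: the (M, N) tile shape selects the matching MFMA builtin at compile time via specialization. Below is a minimal host-side sketch of that pattern; op_16x16 and op_32x32 are hypothetical stand-ins, not the real mfma_type ops.

#include <cstdio>

// Hypothetical stand-ins for the block-size-specific MFMA ops.
struct op_16x16 { static void run() { std::printf("16x16 op\n"); } };
struct op_32x32 { static void run() { std::printf("32x32 op\n"); } };

// Primary template left undefined: unsupported tile shapes fail at compile time.
template <int M, int N>
struct selector;

// Explicit specializations pick the matching op, mirroring mfma_type_selector.
template <> struct selector<16, 16> { void operator()() const { op_16x16::run(); } };
template <> struct selector<32, 32> { void operator()() const { op_32x32::run(); } };

int main()
{
    selector<16, 16>{}();
    selector<32, 32>{}();
    return 0;
}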
@@ -131,43 +119,9 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
// BLOCK_M is the leading-dimension stride of the A matrix
auto startOffset = col_major(startCoord2D, BLOCK_M);
auto kOffset = col_major(stepCoord2D, BLOCK_M);
// kOffset == BLOCK_M
// This means every BLOCK_M-th element is loaded into the output vector
#if 0
auto fragA = AScalarFragT{
bit_cast<ARawT>(input_ptr[startOffset]), // XXX v[0] = Reg 0 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 1 * kOffset]), // XXX v[1] = Reg 0 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 2 * kOffset]), // XXX v[2] = Reg 0 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 3 * kOffset]), // XXX v[3] = Reg 0 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 4 * kOffset]), // XXX v[4] = Reg 1 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 5 * kOffset]), // XXX v[5] = Reg 1 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 6 * kOffset]), // XXX v[6] = Reg 1 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 7 * kOffset]), // XXX v[7] = Reg 1 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 8 * kOffset]), // XXX v[8] = Reg 2 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 9 * kOffset]), // XXX v[9] = Reg 2 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 10 * kOffset]), // XXX v[10] = Reg 2 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 11 * kOffset]), // XXX v[11] = Reg 2 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 12 * kOffset]), // XXX v[12] = Reg 3 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 13 * kOffset]), // XXX v[13] = Reg 3 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 14 * kOffset]), // XXX v[14] = Reg 3 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 15 * kOffset]), // XXX v[15] = Reg 3 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 16 * kOffset]), // XXX v[16] = Reg 4 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 17 * kOffset]), // XXX v[17] = Reg 4 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 18 * kOffset]), // XXX v[18] = Reg 4 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 19 * kOffset]), // XXX v[19] = Reg 4 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 20 * kOffset]), // XXX v[20] = Reg 5 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 21 * kOffset]), // XXX v[21] = Reg 5 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 22 * kOffset]), // XXX v[22] = Reg 5 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 23 * kOffset]), // XXX v[23] = Reg 5 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 24 * kOffset]), // XXX v[24] = Reg 6 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 25 * kOffset]), // XXX v[25] = Reg 6 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 26 * kOffset]), // XXX v[26] = Reg 6 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 27 * kOffset]), // XXX v[27] = Reg 6 [24:31]
bit_cast<ARawT>(input_ptr[startOffset + 28 * kOffset]), // XXX v[28] = Reg 7 [0:7]
bit_cast<ARawT>(input_ptr[startOffset + 29 * kOffset]), // XXX v[29] = Reg 7 [8:15]
bit_cast<ARawT>(input_ptr[startOffset + 30 * kOffset]), // XXX v[30] = Reg 7 [16:23]
bit_cast<ARawT>(input_ptr[startOffset + 31 * kOffset])}; // XXX v[31] = Reg 7 [24:31]
#else
// kOffset == BLOCK_M
// This means every BLOCK_M-th element is loaded into the output vector
auto fragA = AScalarFragT{};
#pragma unroll VW
for(uint32_t i = 0; i < VW; i++)
@@ -175,7 +129,6 @@ __device__ AFragT load_A_col_major(AType const* input_ptr)
fragA[i] = bit_cast<ARawT>(input_ptr[startOffset + i * kOffset]);
}
#endif
return fragA;
}
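The surviving loop performs the same strided gather the removed initializer spelled out by hand: each lane walks the K dimension of the column-major A tile in steps of BLOCK_M. The sketch below reproduces that offset math on the host, assuming BLOCK_M = 16; VW = 32 matches the 32-element initializer removed above, but the lane-to-(row, startK) mapping is not visible in this hunk and is assumed here.

#include <cstdint>
#include <cstdio>
#include <initializer_list>
#include <utility>

int main()
{
    constexpr uint32_t BLOCK_M = 16;  // rows of the A tile (assumed)
    constexpr uint32_t VW      = 32;  // K elements gathered per lane

    // Same flattening as in the kernel: column-major, ld = BLOCK_M.
    auto col_major = [](std::pair<uint32_t, uint32_t> coord, uint32_t ld) {
        return coord.first + coord.second * ld;
    };

    // Assumed lane -> (row, startK) mapping; the hunk only shows the offsets.
    for (uint32_t lane : {0u, 1u, 16u, 63u})
    {
        auto startCoord2D = std::make_pair(lane % BLOCK_M, (lane / BLOCK_M) * VW);
        auto stepCoord2D  = std::make_pair(0u, 1u);           // step along K
        auto startOffset  = col_major(startCoord2D, BLOCK_M);
        auto kOffset      = col_major(stepCoord2D, BLOCK_M);  // == BLOCK_M
        std::printf("lane %2u: first element %u, then every %u-th element\n",
                    lane, startOffset, kOffset);
    }
    return 0;
}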
@@ -237,15 +190,12 @@ __device__ BFragT load_B_col_major(BType const* input_ptr)
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row
threadIdx.x % BLOCK_N); // Col
// auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, BLOCK_K);
// auto kOffset = col_major(stepCoord2D, BLOCK_K);
// kOffset == 1
auto const* fragPtr = reinterpret_cast<BFragT const*>(input_ptr + startOffset);
return *fragPtr;
}
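The single reinterpret_cast load works because the step along K is (1, 0), which flattens to an offset of 1 in a column-major tile with leading dimension BLOCK_K: a lane's VW elements are contiguous in memory. A host-side sketch of that offset math, with VW = 32 assumed for this tile shape:

#include <cstdint>
#include <cstdio>
#include <initializer_list>
#include <utility>

int main()
{
    constexpr uint32_t BLOCK_N = 16;   // columns of the B tile
    constexpr uint32_t BLOCK_K = 128;  // rows (K) of the B tile, also the leading dimension
    constexpr uint32_t VW      = 32;   // elements per lane (assumed)

    auto col_major = [](std::pair<uint32_t, uint32_t> coord, uint32_t ld) {
        return coord.first + coord.second * ld;
    };

    for (uint32_t lane : {0u, 1u, 17u})
    {
        auto startCoord2D = std::make_pair((lane / BLOCK_N) * VW, lane % BLOCK_N);
        auto startOffset  = col_major(startCoord2D, BLOCK_K);
        // Step along K flattens to offset 1: the VW elements are contiguous,
        // so the kernel can reinterpret the pointer as one fragment-wide vector.
        std::printf("lane %2u: elements [%u, %u)\n", lane, startOffset, startOffset + VW);
    }
    return 0;
}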
@@ -278,29 +228,16 @@ struct store_C_col_major<CType, CFragT, 16, 16>
static constexpr uint32_t VW = vectorSize(cFrag); // 4
static constexpr uint32_t Dim = 16;
#if 1
for(int i = 0; i < vectorSize(cFrag); ++i)
{
printf("threadIdx.x = %d; cFrag[%d] = %f\n",
static_cast<int>(threadIdx.x),
i,
static_cast<float>(cFrag[i]));
}
#endif
// Each thread will store 4 elements.
// We need to know where they start, and where the next elements are.
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// auto stepCoord2D = std::make_pair(1u, 0u);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 16);
// auto kOffset = col_major(stepCoord2D, 16); // 1
// kOffset == 1
auto* fragPtr = reinterpret_cast<CFragT*>(output + startOffset);
*fragPtr = cFrag;
}
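The 16x16 store uses the inverse mapping: each lane owns VW = 4 contiguous rows of one column, so the whole fragment can be written with a single vector store. A quick host-side check of that mapping, assuming the usual 64-lane wave:

#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

int main()
{
    constexpr uint32_t Dim = 16, VW = 4, WAVE_SIZE = 64;  // WAVE_SIZE = 64 assumed

    auto col_major = [](std::pair<uint32_t, uint32_t> coord, uint32_t ld) {
        return coord.first + coord.second * ld;
    };

    // Each of the 64 lanes writes VW contiguous rows of one column.
    std::array<int, Dim * Dim> touched{};
    for (uint32_t lane = 0; lane < WAVE_SIZE; ++lane)
    {
        auto startCoord2D = std::make_pair((lane / Dim) * VW, lane % Dim);
        auto startOffset  = col_major(startCoord2D, Dim);
        for (uint32_t i = 0; i < VW; ++i)
            ++touched[startOffset + i];  // contiguous -> one vector store
    }

    // Every element of the 16x16 tile should be written exactly once.
    bool ok = true;
    for (int t : touched) ok = ok && (t == 1);
    std::printf("coverage %s\n", ok ? "OK" : "BROKEN");
    return 0;
}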
@@ -343,34 +280,19 @@ struct store_C_col_major<CType, CFragT, 32, 32>
static constexpr uint32_t Dim = 32;
static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8
#if 1
for(int i = 0; i < vectorSize(cFrag); ++i)
{
printf("threadIdx.x = %d; cFrag[%d] = %f\n",
static_cast<int>(threadIdx.x),
i,
static_cast<float>(cFrag[i]));
}
#endif
auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row
threadIdx.x % Dim); // Col
// Minor step for each 'chunk'
// auto minorStepCoord2D = std::make_pair(1u, 0u);
// Major step between 'chunks'
auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0);
// Flatten to 1D col_major offsets.
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; };
auto startOffset = col_major(startCoord2D, 32);
// auto kMinorOffset = col_major(minorStepCoord2D, 32); // 1
auto kMajorOffset = col_major(majorStepCoord2D, 32); // 8
// kMinorOffset == 1.
// This means we can vector store 4 contiguous elements at a time.
// We can vector store 4 contiguous elements at a time.
using CRawT = typename scalar_type<CFragT>::type;
using CScalarFragT = vector_type<CRawT, VW>::type;
union
@@ -444,16 +366,9 @@ __global__ void matmul(const AType* a, const BType* b, CType* c)
*/
struct GemmParams
{
/**
* @brief This constructor initializes the parameters for GEMM storage with default values.
*
* A[16x128] * B[128x16] = C[16x16], all row major.
*/
GemmParams() : M(16), N(16), K(128) {}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 128;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
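The last hunk swaps the documented constructor for default member initializers; both forms yield the same M = 16, N = 16, K = 128 defaults. A minimal illustration, using plain int in place of ck::index_t:

#include <cassert>

// Old style: defaults supplied by a user-declared constructor.
struct GemmParamsCtor
{
    GemmParamsCtor() : M(16), N(16), K(128) {}
    int M;
    int N;
    int K;
};

// New style: default member initializers, no constructor needed.
struct GemmParamsNsdmi
{
    int M = 16;
    int N = 16;
    int K = 128;
};

int main()
{
    GemmParamsCtor a;
    GemmParamsNsdmi b;
    assert(a.M == b.M && a.N == b.N && a.K == b.K);
    return 0;
}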
......