Commit a2e2fc6c authored by Andriy Roshchenko
Browse files

Print scale to thread mapping for `mfma_scale_f32_32x32x64`

parent ca567c60
@@ -258,6 +258,96 @@ __global__ void kernel_b_scale_mapping()
dataAB regA(0); dataAB regA(0);
dataAB regB(0); dataAB regB(0);
dataC regC(0.0f); dataC regC(0.0f);
dataX xa(0x3F800000);
dataX xb{bit_cast<int32_t>(threadIdx.x) + 0x7F,
bit_cast<int32_t>(threadIdx.x) + 0x7F}; // 127{2^0}, 127+1{2^1},...,127+63{2^63}
// fill first row of A with 1.0
if(threadIdx.x == 0 || threadIdx.x == 32)
{
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0x38; // 1.0
}
}
// verify scale mapping for each row
for(int colId = 0; colId < 32; colId++)
{
for(int testId = 0; testId < 64; testId++)
{
if(threadIdx.x == 0 && false)
{
printf("testId: %d\n", testId);
}
regB = dataAB(0);
regC = dataC(0.0f);
if(threadIdx.x == 0 + colId && testId < 32) // first half
{
// set b(testId, colId) = 1.0
regB[testId] = 0x38; // 1.0
}
else if(threadIdx.x == 32 + colId && 32 <= testId) // upper 32 entries
{
// set b(testId, colId) = 1.0
regB[testId % 32] = 0x38; // 1.0
}
__syncthreads();
regC = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(regA,
regB,
regC,
0, // cbsz
0, // blgp
0,
xa[threadIdx.x / 32],
0,
xb[threadIdx.x / 32]);
__syncthreads();
if(threadIdx.x % 32 == 0 && false) // row 0
{
printf(
"thread: %u -- xA: %x\n", threadIdx.x, bit_cast<int32_t>(xa[threadIdx.x / 32]));
printf(
"thread: %u -- xB: %x\n", threadIdx.x, bit_cast<int32_t>(xb[threadIdx.x / 32]));
}
// Size | BLOCK_N | BLOCK_N |
// N | 0 ... 31 | 0 ... 31 |
// Thread Id | 0 ... 31 | 32 ... 63 | Vector
// Register Element ------------ ------------- Element
// Reg0 | M0 | M4 | v[0]
// Reg1 | M1 | M5 | v[1]
// Reg2 | M2 | M6 | v[2]
// Reg3 | M3 | M7 | v[3]
// ____________ _____________
// Reg4 | M8 | M12 | v[4]
// Reg5 | M9 | M13 | v[5]
// Reg6 | M10 | M14 | v[6]
// Reg7 | M11 | M15 | v[7]
// ____________ _____________
// Reg8 | M16 | M20 | v[8]
// Reg9 | M17 | M21 | v[9]
// Reg10 | M18 | M22 | v[10]
// Reg11 | M19 | M23 | v[11]
// ____________ _____________
// Reg12 | M24 | M28 | v[12]
// Reg13 | M25 | M29 | v[13]
// Reg14 | M26 | M30 | v[14]
// Reg15 | M27 | M31 | v[15]
if(threadIdx.x == colId)
{
printf("b(%d,%d) is scaled from thread %f\n", testId, colId, log2f(regC[0]));
}
}
if(threadIdx.x == 32)
{
printf("\n");
}
}
} }
int main() int main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment