Commit 2d4fb7d5 authored by Andriy Roshchenko
Browse files

Latest reproducer for SCALE MFMA

parent f3af1da6
#include <hip/hip_ext.h>
#include <hip/hip_runtime.h>
// Reinterprets the bits of `x` as a value of type `Y` via the compiler's
// __builtin_bit_cast, without going through a pointer cast.
//
// Template parameters:
//   Y - destination type (must be spelled explicitly by the caller).
//   X - source type (deduced from the argument).
// The unnamed third parameter removes this overload from consideration
// whenever sizeof(X) != sizeof(Y).
//
// Declared __host__ __device__ constexpr, so it is callable from both host
// and device code and may be evaluated at compile time.
template <typename Y,
          typename X,
          typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)
{
    // The enable_if above already enforces equal sizes; this assert restates
    // the requirement with a human-readable message.
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    // Fail the build early if the compiler does not provide the builtin.
    static_assert(__has_builtin(__builtin_bit_cast), "");

    return __builtin_bit_cast(Y, x);
}
__global__ void kernel()
{
using dataAB = uint8_t __attribute__((ext_vector_type(32)));
using dataC = float __attribute__((ext_vector_type(16)));
using dataX = int32_t __attribute__((ext_vector_type(2)));
dataAB regA(0x38);
dataAB regB(0x38);
dataC regC(1.0f);
dataAB regA(0);
dataAB regB(0);
dataC regC(0.0f);
// dataC regCin(1.0f);
#if 1
// dataX xa{127, 127}; // 1.0
dataX xa(127 & 0xFF); // 1.0
dataX xb(127 & 0xFF); // 1.0
#if 0
dataX xa{0x3F800000, 0x3F800000};
dataX xb(0x3F800000);
#elif 0
dataX xa{0x3F000000, 0x3F000000};
dataX xb(0x3F800000);
#elif 0
dataX xa{0x3F800000, 0x3F000000};
dataX xb(0x3F800000);
#elif 1
dataX xa{0x7F, 0x7E}; // expect 64 at c(0,0)
dataX xb(0x3F800000);
// dataX xb(0x7F);
#else
dataX xa(0);
dataX xb(0);
#endif
#if 0
#if 1
if(threadIdx.x == 0)
{
// xa = 127; // 1.0
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0x38; // 1.0
......@@ -31,27 +50,23 @@ __global__ void kernel()
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
if(threadIdx.x == 32)
{
// xa = 126; // 0.5
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0xC0; // -2.0
regA[i] = 0x40; // 2.0
}
for(size_t i = 0; i < 32; i++)
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
#endif
__syncthreads();
#if 1
printf("thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
......@@ -123,7 +138,7 @@ __global__ void kernel()
regB[29],
regB[30],
regB[31]);
#endif
//__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
regC = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(regA,
regB,
......@@ -136,6 +151,11 @@ __global__ void kernel()
xb[threadIdx.x / 32]);
__syncthreads();
if(threadIdx.x == 0 || threadIdx.x == 32)
{
printf("thread: %u -- xA: %x\n", threadIdx.x, bit_cast<int32_t>(xa[threadIdx.x / 32]));
printf("thread: %u -- xB: %x\n", threadIdx.x, bit_cast<int32_t>(xb[threadIdx.x / 32]));
}
printf("thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
threadIdx.x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment