Commit 1a90f021 authored by Andriy Roshchenko

WIP: Debug SCALE MFMA instruction

parent 1e054452
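
For reference while reading the hex constants in the hunks below: the matrix operands are f8 (e4m3) bytes and the scale operands are e8m0 exponent bytes with bias 127. A minimal host-side sketch of the decoding (decode_e8m0 is an illustrative helper, not part of CK):

#include <cmath>
#include <cstdint>
#include <cstdio>

// e8m0: value = 2^(e - 127); no sign bit, no mantissa.
float decode_e8m0(uint8_t e) { return std::exp2f(static_cast<int>(e) - 127); }

int main()
{
    printf("126 -> %g, 127 -> %g, 128 -> %g\n",
           decode_e8m0(126),
           decode_e8m0(127),
           decode_e8m0(128)); // 0.5, 1, 2
    // f8 (e4m3, bias 7): 0x38 = 0'0111'000 -> +2^(7-7) * 1.0 = 1.0
    //                    0xC0 = 1'1000'000 -> -2^(8-7) * 1.0 = -2.0
    return 0;
}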
@@ -519,12 +519,36 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const int32_t& scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
if(threadIdx.x == 0 || threadIdx.x == 32)
{
printf("thread: %u -- xA: %x\n", threadIdx.x, static_cast<uint32_t>(scale_a));
printf("thread: %u -- xB: %x\n", threadIdx.x, static_cast<uint32_t>(scale_b));
// printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_a: %f\n",
// threadIdx.x,
// static_cast<float>(ck::e8m0_bexp_t(scale_a)));
// printf("intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> thread: %u -- scale_b: %f\n",
// threadIdx.x,
// static_cast<float>(ck::e8m0_bexp_t(scale_b)));
// for(size_t i = 0; i < 32; i++)
// {
// printf("thread: %u -- reg_a[%zu]: %f\n",
// threadIdx.x,
// i,
// type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_a)[i]}));
// // printf("thread: %u -- reg_b[%zu]: %f\n",
// // threadIdx.x,
// // i,
// // type_convert<float>(f8_t{static_cast<f8x32_t::data_v>(reg_b)[i]}));
// }
}
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
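// Per that test, the operand order is (assumed): A, B, C-in, cbsz, blgp,
// scale-A byte select, scale-A, scale-B byte select, scale-B.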
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
......
@@ -110,16 +110,51 @@ bool run_mxmfma_test(ck::index_t init)
return pass;
}
TEST(MXMFMA, MXFP8MFMA16x16x128)
TEST(MXMFMA, MXFP8MFMA16x16x128i2)
{
auto AB_init = 2;
auto pass = run_mxmfma_test<f8_t, f8_t, half_t, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64)
TEST(MXMFMA, MXFP8MFMA32x32x64i2)
{
auto AB_init = 2;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA16x16x128i3)
{
auto AB_init = 3;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i3)
{
auto AB_init = 3;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA16x16x128i4)
{
auto AB_init = 4;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i4)
{
auto AB_init = 4;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
TEST(MXMFMA, MXFP8MFMA32x32x64i5)
{
auto AB_init = 5;
auto pass = run_mxmfma_test<f8_t, f8_t, float, ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init);
EXPECT_TRUE(pass);
}
@@ -699,6 +699,82 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
// Scaled Matrix multiply-accumulate using MFMA units
// Accumulation intermediate = BLOCK_M x BLOCK_N
__syncthreads();
// printf("thread: %u -- fragXa: %d\n", threadIdx.x, fragXa);
printf("thread: %u -- fragA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
fragA.data_.dN[0],
fragA.data_.dN[1],
fragA.data_.dN[2],
fragA.data_.dN[3],
fragA.data_.dN[4],
fragA.data_.dN[5],
fragA.data_.dN[6],
fragA.data_.dN[7],
fragA.data_.dN[8],
fragA.data_.dN[9],
fragA.data_.dN[10],
fragA.data_.dN[11],
fragA.data_.dN[12],
fragA.data_.dN[13],
fragA.data_.dN[14],
fragA.data_.dN[15],
fragA.data_.dN[16],
fragA.data_.dN[17],
fragA.data_.dN[18],
fragA.data_.dN[19],
fragA.data_.dN[20],
fragA.data_.dN[21],
fragA.data_.dN[22],
fragA.data_.dN[23],
fragA.data_.dN[24],
fragA.data_.dN[25],
fragA.data_.dN[26],
fragA.data_.dN[27],
fragA.data_.dN[28],
fragA.data_.dN[29],
fragA.data_.dN[30],
fragA.data_.dN[31]);
printf("thread: %u -- fragB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
fragB.data_.dN[0],
fragB.data_.dN[1],
fragB.data_.dN[2],
fragB.data_.dN[3],
fragB.data_.dN[4],
fragB.data_.dN[5],
fragB.data_.dN[6],
fragB.data_.dN[7],
fragB.data_.dN[8],
fragB.data_.dN[9],
fragB.data_.dN[10],
fragB.data_.dN[11],
fragB.data_.dN[12],
fragB.data_.dN[13],
fragB.data_.dN[14],
fragB.data_.dN[15],
fragB.data_.dN[16],
fragB.data_.dN[17],
fragB.data_.dN[18],
fragB.data_.dN[19],
fragB.data_.dN[20],
fragB.data_.dN[21],
fragB.data_.dN[22],
fragB.data_.dN[23],
fragB.data_.dN[24],
fragB.data_.dN[25],
fragB.data_.dN[26],
fragB.data_.dN[27],
fragB.data_.dN[28],
fragB.data_.dN[29],
fragB.data_.dN[30],
fragB.data_.dN[31]);
//__builtin_amdgcn_mfma_ld_scale_b32(fragXa, 0, 0);
mfma_type_selector<AFragT, BFragT, AccumFragT, BLOCK_M, BLOCK_N>{}(
fragA, fragXa, fragB, fragXb, fragAcc);
__syncthreads();
@@ -707,7 +783,7 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb,
{
fragC[i] = type_convert<CType>(fragAcc.template AsType<RawAccumFragT>()[Number<0>{}][i]);
}
__syncthreads();
auto storeC = store_C_row_major<CType, CFragT, BLOCK_M, BLOCK_N>{};
storeC(c, fragC);
}
@@ -764,8 +840,8 @@ void RunHostGEMM(const Tensor<ADataType>& A,
{
for(size_t k = 0; k < K; k++)
{
a_m_k(m, k) = type_convert<float>(type_convert<ADataType>(
type_convert<float>(A(m, k)) * type_convert<float>(a_scales(m, k / BLOCK_X))));
a_m_k(m, k) =
type_convert<float>(A(m, k)) * type_convert<float>(a_scales(m, k / BLOCK_X));
}
}
@@ -773,8 +849,8 @@ void RunHostGEMM(const Tensor<ADataType>& A,
{
for(size_t k = 0; k < K; k++)
{
b_k_n(k, n) = type_convert<float>(type_convert<BDataType>(
type_convert<float>(B(k, n)) * type_convert<float>(b_scales(k / BLOCK_X, n))));
b_k_n(k, n) =
type_convert<float>(B(k, n)) * type_convert<float>(b_scales(k / BLOCK_X, n));
}
}
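// Reference being checked: c(m, n) = sum over k of
// (A(m, k) * a_scales(m, k / BLOCK_X)) * (B(k, n) * b_scales(k / BLOCK_X, n)).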
@@ -897,28 +973,60 @@ struct TestMXMFMA
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
// case 3:
// // expect small round off errors
// a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
// a_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
// b_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 0>{-9});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.3f});
a_scales.GenerateTensorValue(GeneratorTensor_2<ScaleType>{126, 128}); // scales: {0.5, 1}
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
break;
case 5:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.0f});
for(size_t i = 0; i < 32; i++)
{
a_m_k(0, i) = type_convert<ADataType>(1.0f);
}
for(size_t i = 32; i < 64; i++)
{
a_m_k(0, i) = type_convert<ADataType>(-2.0f);
}
// printf("f8 1: %x \n", type_convert<ADataType>(1.0f).data);
// printf("f8 -2: %x \n", type_convert<ADataType>(-2.0f).data);
a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
a_scales(0, 0) = ScaleType{1.0f};
a_scales(0, 1) = ScaleType{0.5f};
b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{0.0f});
b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
for(size_t i = 0; i < 64; i++)
{
b_n_k(i, 0) = type_convert<BDataType>(1.0f);
}
break;
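// Intended result for row 0, assuming BLOCK_X == 32 (one scale per
// 32-element k-block) and a full B column of ones:
// c(0, 0) = 32 * (1.0 * 1.0) + 32 * (-2.0 * 0.5) = 32 - 32 = 0.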
// case 3:
// // expect small round off errors
// a_m_k.GenerateTensorValue(GeneratorTensor_4<ADataType>(-1, 3));
// a_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// b_n_k.GenerateTensorValue(GeneratorTensor_4<BDataType>(1, 3));
// b_scales.GenerateTensorValue(
// GeneratorTensor_2<ScaleType>{126, 129}); // scales: {0.5, 1, 2}
// break;
// case 4:
// a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
// a_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 0>{-9});
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// b_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
// break;
// case 5:
// a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1.0f});
// a_scales.GenerateTensorValue(GeneratorTensor_1<ScaleType>{ScaleType{1.0f}});
// b_n_k.GenerateTensorValue(GeneratorTensor_1<BDataType>{1.0f});
// b_scales.GenerateTensorValue(GeneratorTensor_Sequential<ScaleType, 1>{-9});
// break;
case 6:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{0.00195312f});
@@ -990,7 +1098,7 @@ struct TestMXMFMA
RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device);
#if 1
#if 0
#if 1
std::cout << "a:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
@@ -1002,18 +1110,18 @@ struct TestMXMFMA
std::cout << std::endl;
break;
}
std::cout << "b:" << std::endl;
for(size_t i = 0; i < BLOCK_K; i++)
{
for(size_t j = 0; j < BLOCK_N; j++)
{
if(j == 0)
std::cout << type_convert<float>(b(i, j)) << " ";
}
std::cout << std::endl;
}
// std::cout << "b:" << std::endl;
// for(size_t i = 0; i < BLOCK_K; i++)
// {
// for(size_t j = 0; j < BLOCK_N; j++)
// {
// if(j == 0)
// std::cout << type_convert<float>(b(i, j)) << " ";
// }
// std::cout << std::endl;
// }
#endif
#if 1
#if 0
std::cout << "a_scale:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
{
@@ -1023,15 +1131,15 @@ struct TestMXMFMA
}
std::cout << std::endl;
}
std::cout << "b_scale:" << std::endl;
for(size_t i = 0; i < BLOCK_K / BLOCK_X; i++)
{
for(size_t j = 0; j < BLOCK_N; j++)
{
std::cout << type_convert<float>(b_scales(i, j)) << " ";
}
std::cout << std::endl;
}
// std::cout << "b_scale:" << std::endl;
// for(size_t i = 0; i < BLOCK_K / BLOCK_X; i++)
// {
// for(size_t j = 0; j < BLOCK_N; j++)
// {
// std::cout << type_convert<float>(b_scales(i, j)) << " ";
// }
// std::cout << std::endl;
// }
#endif
std::cout << "c_device:" << std::endl;
for(size_t i = 0; i < BLOCK_M; i++)
......
#include <hip/hip_ext.h>
#include <hip/hip_runtime.h>
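// Standalone reproducer for the scaled 32x32x64 MFMA: one wavefront, A and B
// filled with f8 1.0 (0x38), scales set to e8m0 1.0 (127), C seeded with 1.0.
// Expected accumulator: 64 * (1.0 * 1.0) + 1.0 = 65.0 in every element.
// (Assumes a gfx950 build, e.g. hipcc --offload-arch=gfx950.)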
__global__ void kernel()
{
using dataAB = uint8_t __attribute__((ext_vector_type(32)));
using dataC = float __attribute__((ext_vector_type(16)));
using dataX = int32_t __attribute__((ext_vector_type(2)));
dataAB regA(0x38);
dataAB regB(0x38);
dataC regC(1.0f);
// dataC regCin(1.0f);
#if 1
// dataX xa{127, 127}; // 1.0
dataX xa(127 & 0xFF); // 1.0
dataX xb(127 & 0xFF); // 1.0
#else
dataX xa(0);
dataX xb(0);
#endif
#if 0
if(threadIdx.x == 0)
{
// xa = 127; // 1.0
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0x38; // 1.0
}
for(size_t i = 0; i < 32; i++)
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
if(threadIdx.x == 32)
{
// xa = 126; // 0.5
for(size_t i = 0; i < 32; i++)
{
regA[i] = 0xC0; // -2.0
}
for(size_t i = 0; i < 32; i++)
{
regB[i] = 0x38; // 1.0
}
printf("thread: %u -- xA: %x\n", threadIdx.x, xa[threadIdx.x / 32]);
printf("thread: %u -- xB: %x\n", threadIdx.x, xb[threadIdx.x / 32]);
}
#endif
__syncthreads();
printf("thread: %u -- regA: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
regA[0],
regA[1],
regA[2],
regA[3],
regA[4],
regA[5],
regA[6],
regA[7],
regA[8],
regA[9],
regA[10],
regA[11],
regA[12],
regA[13],
regA[14],
regA[15],
regA[16],
regA[17],
regA[18],
regA[19],
regA[20],
regA[21],
regA[22],
regA[23],
regA[24],
regA[25],
regA[26],
regA[27],
regA[28],
regA[29],
regA[30],
regA[31]);
printf("thread: %u -- regB: %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x %x "
"%x %x %x %x %x %x %x %x %x %x\n",
threadIdx.x,
regB[0],
regB[1],
regB[2],
regB[3],
regB[4],
regB[5],
regB[6],
regB[7],
regB[8],
regB[9],
regB[10],
regB[11],
regB[12],
regB[13],
regB[14],
regB[15],
regB[16],
regB[17],
regB[18],
regB[19],
regB[20],
regB[21],
regB[22],
regB[23],
regB[24],
regB[25],
regB[26],
regB[27],
regB[28],
regB[29],
regB[30],
regB[31]);
//__builtin_amdgcn_mfma_ld_scale_b32(xb[threadIdx.x / 32], 0, 0);
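// Lanes 0-31 feed element 0 of the scale vectors, lanes 32-63 element 1.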
regC = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(regA,
regB,
regC,
0, // cbsz
0, // blgp
0, // assumed: scale A byte select (OPSEL)
xa[threadIdx.x / 32],
0, // assumed: scale B byte select (OPSEL)
xb[threadIdx.x / 32]);
__syncthreads();
printf("thread: %u -- regC: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
threadIdx.x,
regC[0],
regC[1],
regC[2],
regC[3],
regC[4],
regC[5],
regC[6],
regC[7],
regC[8],
regC[9],
regC[10],
regC[11],
regC[12],
regC[13],
regC[14],
regC[15]);
// printf("thread: %u -- regCin: %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f %f\n",
// threadIdx.x,
// regCin[0],
// regCin[1],
// regCin[2],
// regCin[3],
// regCin[4],
// regCin[5],
// regCin[6],
// regCin[7],
// regCin[8],
// regCin[9],
// regCin[10],
// regCin[11],
// regCin[12],
// regCin[13],
// regCin[14],
// regCin[15]);
}
int main()
{
    kernel<<<1, 64>>>();
    // Block until the kernel finishes so device-side printf output is
    // flushed before the process exits.
    hipDeviceSynchronize();
    return 0;
}
\ No newline at end of file