Commit d6a7acf6 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed int4 to bf16 conversion

parent 9de3a085
......@@ -65,7 +65,7 @@ using DeviceGemmV2Instance =
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
#endif
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, CDataType, CDataType, false, PermuteB>;
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, false, PermuteB>;
// clang-format on
......
......@@ -73,10 +73,10 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
fp32_intermediates[3] -= 8388616.f;
vector_type<bhalf_t, 4> res;
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632));
res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632));
__byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
__byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
......@@ -135,8 +135,8 @@ struct PassThroughPack8
#if 1
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
......
......@@ -45,6 +45,24 @@ __global__ void
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
// int q = 0x01234567;
// ck::vector_type<ck::bhalf_t, 8> res;
// res.template AsType<ck::bhalf4_t>()(ck::Number<0>{}) = ck::pki4_to_bhalf4(q >> 16);
// res.template AsType<ck::bhalf4_t>()(ck::Number<1>{}) = ck::pki4_to_bhalf4(q);
// if(threadIdx.x == 0 && blockIdx.x == 0)
// printf("%f %f %f %f %f %f %f %f\n",
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<0>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<1>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<2>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<3>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<4>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<5>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<6>{}]),
// ck::type_convert<float>(res.template AsType<ck::bhalf_t>()[Number<7>{}])
//);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment