Commit 05ab9105 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed reference and host_tensor

parent 205e0365
......@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
b_k_n(0, 0) = 0xaa;
b_k_n(1, 1) = 0xaa;
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool pass = true;
if(config.do_verification)
{
//auto ref_gemm = ReferenceGemmInstance{};
//auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
//auto ref_argument = ref_gemm.MakeArgument(
// a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
//ref_invoker.Run(ref_argument);
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
//pass &= ck::utils::check_err(c_m_n_device_result,
// c_m_n_host_result,
// "Error: Incorrect results!",
// get_rtol<CDataType>(),
// get_atol<CDataType>());
//for(int i = 0; i < M; i++)
//{
// for(int j = 0; j < N; j++)
// {
// std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
// }
// std::cout << std::endl;
//}
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
std::cout << "c_m_n_device_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
}
std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
}
if(config.time_kernel)
......
......@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
//reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
//reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
}
};
......
......@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
arg.b_element_op_(v_b, i4);
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
......
......@@ -322,7 +322,12 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
std::size_t GetElementSpaceSize() const {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2;
else
return mDesc.GetElementSpaceSize();
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
......@@ -468,31 +473,66 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
}
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
}
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
}
typename Data::iterator begin() { return mData.begin(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment