"...composable_kernel.git" did not exist on "6bc9ee057ea9a08ed63662fd534ace12c43dd82f"
Commit 05ab9105 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed reference and host_tensor

parent 205e0365
...@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -158,6 +158,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
} }
b_k_n(0, 0) = 0xaa;
b_k_n(1, 1) = 0xaa;
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
...@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -207,31 +210,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
bool pass = true; bool pass = true;
if(config.do_verification) if(config.do_verification)
{ {
//auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
//auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
//auto ref_argument = ref_gemm.MakeArgument( auto ref_argument = ref_gemm.MakeArgument(
// a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
//ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
//pass &= ck::utils::check_err(c_m_n_device_result, pass &= ck::utils::check_err(c_m_n_device_result,
// c_m_n_host_result, c_m_n_host_result,
// "Error: Incorrect results!", "Error: Incorrect results!",
// get_rtol<CDataType>(), get_rtol<CDataType>(),
// get_atol<CDataType>()); get_atol<CDataType>());
//for(int i = 0; i < M; i++) std::cout << "c_m_n_device_result: " << std::endl;
//{ for(int i = 0; i < M; i++)
// for(int j = 0; j < N; j++) {
// { for(int j = 0; j < N; j++)
// std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ","; {
// } std::cout << ck::type_convert<float>(c_m_n_device_result(i, j)) << ",";
// std::cout << std::endl; }
//} std::cout << std::endl;
}
std::cout << "c_m_n_host_result: " << std::endl;
for(int i = 0; i < M; i++)
{
for(int j = 0; j < N; j++)
{
std::cout << ck::type_convert<float>(c_m_n_host_result(i, j)) << ",";
}
std::cout << std::endl;
}
} }
if(config.time_kernel) if(config.time_kernel)
......
...@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16> ...@@ -157,8 +157,8 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{ {
//reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
//reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
} }
}; };
......
...@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -84,6 +84,17 @@ struct ReferenceGemm : public device::BaseOperator
{ {
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
} }
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
arg.b_element_op_(v_b, i4);
}
else else
{ {
arg.b_element_op_(v_b, arg.b_k_n_(k, n)); arg.b_element_op_(v_b, arg.b_k_n_(k, n));
......
...@@ -322,7 +322,12 @@ struct Tensor ...@@ -322,7 +322,12 @@ struct Tensor
std::size_t GetElementSize() const { return mDesc.GetElementSize(); } std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); } std::size_t GetElementSpaceSize() const {
if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
return mDesc.GetElementSpaceSize() / 2;
else
return mDesc.GetElementSpaceSize();
}
std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
...@@ -469,29 +474,64 @@ struct Tensor ...@@ -469,29 +474,64 @@ struct Tensor
template <typename... Is> template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const std::size_t GetOffsetFromMultiIndex(Is... is) const
{ {
return mDesc.GetOffsetFromMultiIndex(is...); if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
else
{
return mDesc.GetOffsetFromMultiIndex(is...);
}
} }
template <typename... Is> template <typename... Is>
T& operator()(Is... is) T& operator()(Is... is)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...)]; if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
} }
template <typename... Is> template <typename... Is>
const T& operator()(Is... is) const const T& operator()(Is... is) const
{ {
return mData[mDesc.GetOffsetFromMultiIndex(is...)]; if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
}
} }
T& operator()(std::vector<std::size_t> idx) T& operator()(std::vector<std::size_t> idx)
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx)]; if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
} }
const T& operator()(std::vector<std::size_t> idx) const const T& operator()(std::vector<std::size_t> idx) const
{ {
return mData[mDesc.GetOffsetFromMultiIndex(idx)]; if constexpr(ck::is_same_v<T, ck::pk_i4_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
else
{
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
} }
typename Data::iterator begin() { return mData.begin(); } typename Data::iterator begin() { return mData.begin(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment