Commit 21885e9c authored by ltqin's avatar ltqin
Browse files

add print data

parent 86834375
...@@ -57,6 +57,22 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds ...@@ -57,6 +57,22 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipLds
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
{
os << "[" << std::endl;
for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
{
os << "[";
for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
{
os << std::setw(5) << static_cast<float>(matrix(x, y));
}
os << "]" << std::endl;
}
os << "]";
return os;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = 0;
...@@ -136,8 +152,9 @@ int main(int argc, char* argv[]) ...@@ -136,8 +152,9 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); // a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
...@@ -199,6 +216,14 @@ int main(int argc, char* argv[]) ...@@ -199,6 +216,14 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
#if 1
{
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl;
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
}
#endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment