Commit ef77a1ca authored by qinletao's avatar qinletao
Browse files

format log out

parent dcdbed2a
......@@ -47,13 +47,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 1
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
#else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = float;
using BDataType = float;
using CDataType = float;
......@@ -64,6 +64,23 @@ using AccDataType = float;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
template <typename DataType>
std::ostream& void show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
{
os << "[" << std::endl;
for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
{
os << "[";
for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
{
os << std::setw(4) << static_cast<float>(matrix(x, y));
}
os << "]" << std::endl;
}
os << "]";
return os;
}
int main(int argc, char* argv[])
{
bool do_verification = 0;
......@@ -144,8 +161,8 @@ int main(int argc, char* argv[])
break;
default:
// a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
// b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
}
......@@ -208,13 +225,12 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
#if 0
#if 1
{
LogRangeAsType<AccDataType>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<AccDataType>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<AccDataType>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl;
LogRangeAsType<AccDataType>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl;
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
}
#endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
......
......@@ -49,14 +49,14 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
64, // MPerBlock
64, // NPerBlock
4, // K0PerBlock
2, // K1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......
......@@ -25,7 +25,7 @@ enum struct MfmaInstr
mfma_f32_16x16x8bf16,
mfma_i32_32x32x8i8,
mfma_i32_16x16x16i8,
mfma_f64_16x16x4f64,
mfma_f64_16x16x4f64
};
template <MfmaInstr instr>
......
......@@ -63,8 +63,8 @@ struct ReferenceGemm : public device::BaseOperator
AccDataType v_a;
AccDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment