"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "4e77382034c6906881249b92f658569b4328d2e8"
Commit ef77a1ca authored by qinletao's avatar qinletao
Browse files

format log out

parent dcdbed2a
...@@ -47,13 +47,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl ...@@ -47,13 +47,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if 1 #if 1
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>; < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = double; using ADataType = double;
using BDataType = double; using BDataType = double;
using CDataType = double; using CDataType = double;
using AccDataType = double; using AccDataType = double;
#else #else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>; < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
using ADataType = float; using ADataType = float;
using BDataType = float; using BDataType = float;
using CDataType = float; using CDataType = float;
...@@ -64,6 +64,23 @@ using AccDataType = float; ...@@ -64,6 +64,23 @@ using AccDataType = float;
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>; ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
template <typename DataType>
std::ostream& void show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
{
os << "[" << std::endl;
for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
{
os << "[";
for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
{
os << std::setw(4) << static_cast<float>(matrix(x, y));
}
os << "]" << std::endl;
}
os << "]";
return os;
}
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool do_verification = 0; bool do_verification = 0;
...@@ -144,8 +161,8 @@ int main(int argc, char* argv[]) ...@@ -144,8 +161,8 @@ int main(int argc, char* argv[])
break; break;
default: default:
// a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); // a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
// b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1}); // b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
} }
...@@ -208,13 +225,12 @@ int main(int argc, char* argv[]) ...@@ -208,13 +225,12 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
#if 0 #if 1
{ {
LogRangeAsType<AccDataType>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
LogRangeAsType<AccDataType>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
LogRangeAsType<AccDataType>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl;
LogRangeAsType<AccDataType>(std::cout << "c_host : ", c_m_n_host_result.mData, ",") show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
<< std::endl;
} }
#endif #endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
......
...@@ -49,14 +49,14 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device:: ...@@ -49,14 +49,14 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
ConvFwdDefault, // ConvForwardSpecialization ConvFwdDefault, // ConvForwardSpecialization
NumDimSpatial, // NumDimSpatial NumDimSpatial, // NumDimSpatial
256, // BlockSize 256, // BlockSize
128, // MPerBlock 64, // MPerBlock
128, // NPerBlock 64, // NPerBlock
4, // K0PerBlock 4, // K0PerBlock
2, // K1 2, // K1
16, // MPerXDL 16, // MPerXDL
16, // NPerXDL 16, // NPerXDL
4, // MXdlPerWave 2, // MXdlPerWave
4, // NXdlPerWave 2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
......
...@@ -25,7 +25,7 @@ enum struct MfmaInstr ...@@ -25,7 +25,7 @@ enum struct MfmaInstr
mfma_f32_16x16x8bf16, mfma_f32_16x16x8bf16,
mfma_i32_32x32x8i8, mfma_i32_32x32x8i8,
mfma_i32_16x16x16i8, mfma_i32_16x16x16i8,
mfma_f64_16x16x4f64, mfma_f64_16x16x4f64
}; };
template <MfmaInstr instr> template <MfmaInstr instr>
......
...@@ -63,8 +63,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -63,8 +63,8 @@ struct ReferenceGemm : public device::BaseOperator
AccDataType v_a; AccDataType v_a;
AccDataType v_b; AccDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment