Commit 04ba72c7 authored by Jing Zhang's avatar Jing Zhang
Browse files

formatting

parent 5a3c2297
......@@ -53,8 +53,8 @@ struct ReferenceGemmTranspose : public device::BaseOperator
auto f_mk_kn_m0m1n0n1 = [&](auto m0, auto m1, auto n0, auto n1) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
const int m = m0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1] + m1;
const int n = n0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3] + n1;
const int m = m0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1] + m1;
const int n = n0 * arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3] + n1;
float v_acc = 0;
......@@ -76,8 +76,11 @@ struct ReferenceGemmTranspose : public device::BaseOperator
arg.c_m0_m1_n0_n1_(m0, m1, n0, n1) = v_c;
};
make_ParallelTensorFunctor(
f_mk_kn_m0m1n0n1, arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[0], arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1], arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[2],arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3])(
make_ParallelTensorFunctor(f_mk_kn_m0m1n0n1,
arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[0],
arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[1],
arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[2],
arg.c_m0_m1_n0_n1_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment