Unverified Commit c667efbd authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #132 from InfiniTensor/issue/130

issue/130: GEMM算子CPU平台的omp重构
parents 4f8afafd 8e2192f4
...@@ -40,31 +40,34 @@ void calculate( ...@@ -40,31 +40,34 @@ void calculate(
std::swap(a, b); std::swap(a, b);
} }
for (size_t i = 0; i < info.batch; ++i) { #pragma omp parallel for
for (size_t m_ = 0; m_ < info.m; ++m_) { for (ptrdiff_t index = 0; index < ptrdiff_t(info.batch * info.m * info.n); ++index) {
for (size_t n_ = 0; n_ < info.n; ++n_) { size_t ind = index;
auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; size_t n_ = ind % info.n;
float sum = 0; ind /= info.n;
for (size_t k_ = 0; k_ < info.k; ++k_) { size_t m_ = ind % info.m;
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; ind /= info.m;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; size_t i = ind;
if constexpr (std::is_same<Tdata, fp16_t>::value) { auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride;
sum += utils::cast<float>(*a_) * utils::cast<float>(*b_); float sum = 0;
} else { for (int k_ = 0; k_ < static_cast<int>(info.k); ++k_) {
sum += *a_ * (*b_); auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
} auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
} if constexpr (std::is_same<Tdata, fp16_t>::value) {
if constexpr (std::is_same<Tdata, fp16_t>::value) { sum += utils::cast<float>(*a_) * utils::cast<float>(*b_);
if (beta == 0) { } else {
*c_ = utils::cast<fp16_t>(alpha * sum); sum += *a_ * (*b_);
} else {
*c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
}
} }
} }
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if (beta == 0) {
*c_ = utils::cast<fp16_t>(alpha * sum);
} else {
*c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
}
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment