Commit bc981f39 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/130: delete collapse

parent a32a8776
......@@ -42,32 +42,35 @@ void calculate(
if (info.is_transed) {
std::swap(a, b);
}
#pragma omp parallel for collapse(3)
for (int i = 0; i < static_cast<int>(info.batch); ++i) {
for (int m_ = 0; m_ < static_cast<int>(info.m); ++m_) {
for (int n_ = 0; n_ < static_cast<int>(info.n); ++n_) {
auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride;
float sum = 0;
for (int k_ = 0; k_ < static_cast<int>(info.k); ++k_) {
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, fp16_t>::value) {
sum += utils::cast<float>(*a_) * utils::cast<float>(*b_);
} else {
sum += *a_ * (*b_);
}
}
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if (beta == 0) {
*c_ = utils::cast<fp16_t>(alpha * sum);
} else {
*c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
}
const size_t m_n = info.m * info.n;
const size_t n = info.n;
#pragma omp parallel for
for (ptrdiff_t index = 0; index < ptrdiff_t(info.batch * info.m * info.n); ++index) {
size_t i, m_, n_;
i = index / m_n;
size_t rem = index - i * m_n; // 替代 `%` 用减法
m_ = rem / n;
n_ = rem - m_ * n; // 替代 `%` 用减法
auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride;
float sum = 0;
for (int k_ = 0; k_ < static_cast<int>(info.k); ++k_) {
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, fp16_t>::value) {
sum += utils::cast<float>(*a_) * utils::cast<float>(*b_);
} else {
sum += *a_ * (*b_);
}
}
if constexpr (std::is_same<Tdata, fp16_t>::value) {
if (beta == 0) {
*c_ = utils::cast<fp16_t>(alpha * sum);
} else {
*c_ = utils::cast<fp16_t>(beta * utils::cast<float>(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment