Commit a32a8776 authored by xgqdut2016
Browse files

issue/130: use int

parent 54b47924
...@@ -43,12 +43,12 @@ void calculate( ...@@ -43,12 +43,12 @@ void calculate(
std::swap(a, b); std::swap(a, b);
} }
#pragma omp parallel for collapse(3) #pragma omp parallel for collapse(3)
for (ptrdiff_t i = 0; i < ptrdiff_t(info.batch); ++i) { for (int i = 0; i < static_cast<int>(info.batch); ++i) {
for (ptrdiff_t m_ = 0; m_ < ptrdiff_t(info.m); ++m_) { for (int m_ = 0; m_ < static_cast<int>(info.m); ++m_) {
for (ptrdiff_t n_ = 0; n_ < ptrdiff_t(info.n); ++n_) { for (int n_ = 0; n_ < static_cast<int>(info.n); ++n_) {
auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride; auto c_ = reinterpret_cast<Tdata *>(c) + i * info.c_matrix.stride + m_ * info.c_matrix.row_stride + n_ * info.c_matrix.col_stride;
float sum = 0; float sum = 0;
for (size_t k_ = 0; k_ < info.k; ++k_) { for (int k_ = 0; k_ < static_cast<int>(info.k); ++k_) {
auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride; auto a_ = reinterpret_cast<const Tdata *>(a) + i * info.a_matrix.stride + m_ * info.a_matrix.row_stride + k_ * info.a_matrix.col_stride;
auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride; auto b_ = reinterpret_cast<const Tdata *>(b) + i * info.b_matrix.stride + n_ * info.b_matrix.col_stride + k_ * info.b_matrix.row_stride;
if constexpr (std::is_same<Tdata, fp16_t>::value) { if constexpr (std::is_same<Tdata, fp16_t>::value) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment