Unverified Commit b25bbe64 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

Loop reorder (#2201)

parent 233e8198
...@@ -44,19 +44,20 @@ void SpMMSumCsr( ...@@ -44,19 +44,20 @@ void SpMMSumCsr(
#pragma omp parallel for #pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) { for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1]; const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim; DType *out_off = O + rid * dim;
for (int64_t k = 0; k < dim; ++k) { std::fill(out_off, out_off + dim, 0);
DType accum = 0;
for (IdType j = row_start; j < row_end; ++j) { for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j]; const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j; const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k; const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k; const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr; const DType *lhs_off =
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr; Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
accum += Op::Call(lhs_off, rhs_off); const DType *rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
} }
out_off[k] = accum;
} }
} }
} }
...@@ -153,30 +154,28 @@ void SpMMCmpCsr( ...@@ -153,30 +154,28 @@ void SpMMCmpCsr(
DType* out_off = O + rid * dim; DType* out_off = O + rid * dim;
IdType* argx_off = argX + rid * dim; IdType* argx_off = argX + rid * dim;
IdType* argw_off = argW + rid * dim; IdType* argw_off = argW + rid * dim;
for (int64_t k = 0; k < dim; ++k) { std::fill(out_off, out_off + dim, Cmp::zero);
DType accum = Cmp::zero; if (Op::use_lhs)
IdType ax = 0, aw = 0; std::fill(argx_off, argx_off + dim, 0);
if (Op::use_rhs)
std::fill(argw_off, argw_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) { for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j]; const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j; const IdType eid = has_idx? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k; const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k; const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr; const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr; const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off); const DType val = Op::Call(lhs_off, rhs_off);
if (Cmp::Call(accum, val)) { if (Cmp::Call(out_off[k], val)) {
accum = val; out_off[k] = val;
if (Op::use_lhs) if (Op::use_lhs)
ax = cid; argx_off[k] = cid;
if (Op::use_rhs) if (Op::use_rhs)
aw = eid; argw_off[k] = eid;
} }
} }
out_off[k] = accum;
if (Op::use_lhs)
argx_off[k] = ax;
if (Op::use_rhs)
argw_off[k] = aw;
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment