Unverified commit b25bbe64, authored by Zihao Ye and committed by GitHub
Browse files

Loop reorder (#2201)

parent 233e8198
......@@ -44,19 +44,20 @@ void SpMMSumCsr(
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (int64_t k = 0; k < dim; ++k) {
DType accum = 0;
DType *out_off = O + rid * dim;
std::fill(out_off, out_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j;
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
accum += Op::Call(lhs_off, rhs_off);
const DType *lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType *rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
}
out_off[k] = accum;
}
}
}
......@@ -153,30 +154,28 @@ void SpMMCmpCsr(
DType* out_off = O + rid * dim;
IdType* argx_off = argX + rid * dim;
IdType* argw_off = argW + rid * dim;
for (int64_t k = 0; k < dim; ++k) {
DType accum = Cmp::zero;
IdType ax = 0, aw = 0;
std::fill(out_off, out_off + dim, Cmp::zero);
if (Op::use_lhs)
std::fill(argx_off, argx_off + dim, 0);
if (Op::use_rhs)
std::fill(argw_off, argw_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off = Op::use_lhs? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off = Op::use_rhs? W + eid * rhs_dim + rhs_add : nullptr;
const DType val = Op::Call(lhs_off, rhs_off);
if (Cmp::Call(accum, val)) {
accum = val;
if (Cmp::Call(out_off[k], val)) {
out_off[k] = val;
if (Op::use_lhs)
ax = cid;
argx_off[k] = cid;
if (Op::use_rhs)
aw = eid;
argw_off[k] = eid;
}
}
out_off[k] = accum;
if (Op::use_lhs)
argx_off[k] = ax;
if (Op::use_rhs)
argw_off[k] = aw;
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment