You need to sign in or sign up before continuing.
Unverified Commit de174ada authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Remove redundant fill in SPMM kernel (#3166)

* remove redundant fill

* trigger ci
parent bbb6d4ee
...@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
std::fill(out_off, out_off + csr.num_rows * dim, 0);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out); cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}); });
}); });
...@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
DType *out_off = out.Ptr<DType>(); DType *out_off = out.Ptr<DType>();
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr; IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr; IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + csr.num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + csr.num_rows * dim, 0);
if (reduce == "max") { if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero); std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>( cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
...@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */ /* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype]; const dgl_type_t src_id = ufeat_node_tids[etype];
...@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} else if (reduce == "max" || reduce == "min") { } else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + vec_csr[etype].num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */ /* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) { for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype]; const dgl_type_t src_id = ufeat_node_tids[etype];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment