Unverified Commit de174ada authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Remove redundant fill in SPMM kernel (#3166)

* remove redundant fill

* trigger ci
parent bbb6d4ee
......@@ -22,8 +22,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
std::fill(out_off, out_off + csr.num_rows * dim, 0);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
});
......@@ -33,8 +31,6 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
DType *out_off = out.Ptr<DType>();
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + csr.num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + csr.num_rows * dim, 0);
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
......@@ -66,11 +62,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
......@@ -86,13 +77,6 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + vec_csr[etype].num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment