"docs/git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "e3376abfd4f7923e3a66b13a8f039fbf21ae7f85"
Unverified commit 75ec5826, authored by Israt Nisa and committed by GitHub

Add heterograph support in C kernels (#2882)



* SpMM for heterographs
* C APIs for SDDMM on heterographs
* Passes initial result
* Renamed eid to nid
* Aggregation on the same ntype for multiple etypes
* Fix link check failure
* Lint check, part 2
* Lint check, part 3
* Fixed SpMMCmpCsr Min op
* Added mem references
* Fixed fill(Max/Min), added const
* Removed newline
* Brought back docstring
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 6383e649
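At a high level, the new heterograph kernels take vectors of CSR matrices and feature tensors indexed by edge type and node type and loop over relation types; for the SpMM reducers, each destination buffer is initialized once before the loop so that several edge types can accumulate into the same node type. Below is a minimal sketch of that dispatch pattern with simplified, hypothetical names (Csr, Feat, HeteroDispatchSketch); the real signatures are the SpMMCsrHetero/SDDMMCsrHetero functions in the diff that follows.

// Minimal sketch (not DGL code) of the per-relation dispatch pattern this
// commit introduces.
#include <algorithm>
#include <cstddef>
#include <vector>

struct Csr { std::size_t num_rows; /* indptr/indices elided */ };
using Feat = std::vector<float>;

void HeteroDispatchSketch(const std::vector<Csr>& csr_per_etype,
                          const std::vector<Feat>& ufeat_per_ntype,
                          std::vector<Feat>* out_per_ntype,
                          const std::vector<std::size_t>& src_ntype,   // etype -> source node type
                          const std::vector<std::size_t>& dst_ntype) { // etype -> destination node type
  // Initialize every destination buffer once, before the loop, so that
  // several edge types can accumulate into the same node type.
  for (std::size_t etype = 0; etype < csr_per_etype.size(); ++etype) {
    Feat& out = (*out_per_ntype)[dst_ntype[etype]];
    std::fill(out.begin(), out.end(), 0.f);
  }
  // One homogeneous kernel call per relation type: the CSR is picked by edge
  // type, the input and output features by the endpoint node types.
  for (std::size_t etype = 0; etype < csr_per_etype.size(); ++etype) {
    const Csr& csr = csr_per_etype[etype];
    const Feat& ufeat = ufeat_per_ntype[src_ntype[etype]];
    Feat& out = (*out_per_ntype)[dst_ntype[etype]];
    // ... e.g. SpMMSumCsr(bcast, csr, ufeat, efeat, out) in the real code ...
    (void)csr; (void)ufeat; (void)out;
  }
}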
@@ -74,6 +74,34 @@ void SDDMMCsr(const std::string& op,
});
}
/*! \brief Generalized SDDMM on Csr format with Heterograph support. */
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_nid,
const std::vector<dgl_type_t>& rhs_nid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM for each relation type */
for (dgl_type_t etype = 0; etype < lhs_nid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_nid[etype]];
NDArray rhs = vec_rhs[rhs_nid[etype]];
NDArray out = vec_out[etype];
cpu::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
}
});
});
});
}
template void SDDMMCsr<kDLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
@@ -99,6 +127,48 @@ template void SDDMMCsr<kDLCPU, int64_t, 64>(
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
/*! \brief Generalized SDDMM on Coo format. */
template <int XPU, typename IdType, int bits>
......
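The block of explicit instantiations above exists because the SDDMMCsrHetero template is defined in this source file rather than a header: object code has to be emitted for every (IdType, bits) combination the runtime switches (ATEN_ID_TYPE_SWITCH / ATEN_FLOAT_BITS_SWITCH) can select, i.e. int32_t/int64_t ids and 16/32/64-bit features. A generic illustration of the mechanism (not DGL code):

// Generic illustration of explicit function-template instantiation: the
// definition lives in one translation unit, and these declarations force the
// compiler to emit code for the listed type combinations so other files can
// link against them.
#include <cstdint>
#include <iostream>

template <typename IdType, int bits>
void Kernel(IdType n) {
  std::cout << sizeof(IdType) * 8 << "-bit ids, " << bits
            << "-bit features, n = " << n << "\n";
}

template void Kernel<std::int32_t, 32>(std::int32_t n);
template void Kernel<std::int64_t, 64>(std::int64_t n);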
@@ -18,21 +18,100 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
std::fill(out_off, out_off + csr.num_rows * dim, 0);
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
});
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
DType *out_off = out.Ptr<DType>();
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + csr.num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + csr.num_rows * dim, 0);
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} else {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
}
});
});
} else {
LOG(FATAL) << "Unsupported SpMM reducer: " << reduce;
}
}
/*! \brief Generalized SpMM on Csr format with heterograph support. */
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray> vec_out,
const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids) {
const int64_t dim = bcast.out_len;
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
std::fill(out_off, out_off + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = vec_out[dst_id];
cpu::SpMMSumCsr<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
}
});
});
} else if (reduce == "max" || reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
// TODO(Israt): Ideally the for loop should go over num_ntypes
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
IdType* argX = Op::use_lhs ? static_cast<IdType*>(out_aux[0]->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(out_aux[1]->data) : nullptr;
if (Op::use_lhs) std::fill(argX, argX + vec_csr[etype].num_rows * dim, 0);
if (Op::use_rhs) std::fill(argW, argW + vec_csr[etype].num_rows * dim, 0);
}
/* Call SpMM for each relation type */
for (dgl_type_t etype = 0; etype < ufeat_node_tids.size(); ++etype) {
const dgl_type_t src_id = ufeat_node_tids[etype];
const dgl_type_t dst_id = out_node_tids[etype];
CSRMatrix csr = vec_csr[etype];
DType *out_off = vec_out[out_node_tids[etype]].Ptr<DType>();
NDArray ufeat = (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
NDArray out = vec_out[dst_id];
if (reduce == "max") {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Max<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Max<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
} else {
std::fill(out_off, out_off + csr.num_rows * dim, cpu::op::Min<DType>::zero);
cpu::SpMMCmpCsr<IdType, DType, Op, cpu::op::Min<DType>>(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
}
}
});
});
} else {
@@ -65,6 +144,48 @@ template void SpMMCsr<kDLCPU, int64_t, 64>(
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray> out, const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
/*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, int bits>
......
@@ -57,7 +57,6 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
std::fill(out_off, out_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
@@ -72,7 +71,6 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
std::fill(out_off, out_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
@@ -180,9 +178,6 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
DType* out_off = O + rid * dim;
IdType* argx_off = argX + rid * dim;
IdType* argw_off = argW + rid * dim;
std::fill(out_off, out_off + dim, Cmp::zero);
if (Op::use_lhs) std::fill(argx_off, argx_off + dim, 0);
if (Op::use_rhs) std::fill(argw_off, argw_off + dim, 0);
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
......
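The std::fill removals in this hunk are the counterpart of the fills added to SpMMCsr and SpMMCsrHetero above: output initialization moves out of the per-row kernels (SpMMSumCsr, SpMMCmpCsr) and up into the dispatchers, because with heterographs several relation types may accumulate into the same destination buffer and must not reset it between calls. A toy sketch (not DGL code) of the hazard this avoids:

// Toy sketch: if each relation's kernel zeroed the shared output itself, only
// the last relation's contribution would survive. The dispatcher therefore
// initializes once and the per-relation kernel only accumulates.
#include <algorithm>
#include <cstddef>
#include <vector>

void AccumulateRelation(const std::vector<float>& contrib, std::vector<float>* out) {
  // No std::fill here; the caller has already initialized *out.
  for (std::size_t i = 0; i < out->size(); ++i) (*out)[i] += contrib[i];
}

void Dispatcher(const std::vector<std::vector<float>>& per_etype_contrib,
                std::vector<float>* out) {
  std::fill(out->begin(), out->end(), 0.f);  // initialize exactly once
  for (const auto& contrib : per_etype_contrib)
    AccumulateRelation(contrib, out);        // each relation adds on top
}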
@@ -52,6 +52,54 @@ void SpMM(const std::string& op, const std::string& reduce,
});
}
/*! \brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
std::vector<NDArray> ufeat_vec,
std::vector<NDArray> efeat_vec,
std::vector<NDArray> out,
std::vector<NDArray> out_aux) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
std::vector<CSRMatrix> vec_graph;
std::vector<dgl_type_t> ufeat_eid;
std::vector<dgl_type_t> efeat_eid;
std::vector<dgl_type_t> out_eid;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_graph.push_back(graph->GetCSCMatrix(etype));
auto pair = graph->meta_graph()->FindEdge(etype);
ufeat_eid.push_back(pair.first);
efeat_eid.push_back(etype);
out_eid.push_back(pair.second);
}
NDArray efeat = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[efeat_eid[0]];
NDArray ufeat = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[ufeat_eid[0]];
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
// TODO(Israt): Change it to ATEN_XPU_SWITCH_CUDA when cuda codes are modified
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[out_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, bits>(
op, reduce, bcast, vec_graph,
ufeat_vec, efeat_vec, out, out_aux,
ufeat_eid, out_eid);
// TODO(Israt): Enable it when COO support is added
// } else if (format == SparseFormat::kCOO) {
// SpMMCoo<XPU, IdType, bits>(
// op, reduce, bcast, graph->GetCOOMatrix(0),
// ufeat, vec_efeat, out, out_aux);
// }
} else {
LOG(FATAL) << "SpMM only supports CSC format for heterograph";
}
});
});
});
}
/*! \brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
@@ -83,6 +131,51 @@ void SDDMM(const std::string& op,
});
}
/*! \brief Generalized Sampled Dense-Dense Matrix Multiplication with heterograph support. */
void SDDMMHetero(const std::string& op,
HeteroGraphPtr graph,
std::vector<NDArray> lhs,
std::vector<NDArray> rhs,
std::vector<NDArray> out,
int lhs_target,
int rhs_target) {
// TODO(Israt): change it to COO_CODE
SparseFormat format = graph->SelectFormat(0, CSR_CODE);
std::vector<CSRMatrix> vec_csr;
std::vector<dgl_type_t> lhs_eid;
std::vector<dgl_type_t> rhs_eid;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
auto pair = graph->meta_graph()->FindEdge(etype);
lhs_eid.push_back(pair.first);
rhs_eid.push_back(pair.second);
}
const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
// TODO(Israt): change it to ATEN_XPU_SWITCH_CUDA when cuda codes are modified
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", {
if (format == SparseFormat::kCSR) {
SDDMMCsrHetero<XPU, IdType, bits>(
op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
// TODO(Israt): Enable it when COO support is added
// } else if (format == SparseFormat::kCOO) {
// SDDMMCoo<XPU, IdType, bits>(
// op, bcast, graph->GetCOOMatrix(0),
// lhs, rhs, out, lhs_target, rhs_target);
} else {
LOG(FATAL) << "SDDMM only supports CSR format";
}
});
});
});
}
NDArray GetEdgeMapping(HeteroGraphRef graph) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
if (format == SparseFormat::kCSC) {
@@ -217,6 +310,45 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
List<Value> list_U = args[3];
List<Value> list_E = args[4];
List<Value> list_V = args[5];
NDArray ArgU = args[6];
NDArray ArgE = args[7];
std::vector<NDArray> U_vec;
std::vector<NDArray> V_vec;
std::vector<NDArray> E_vec;
U_vec.reserve(list_U.size());
V_vec.reserve(list_V.size());
E_vec.reserve(list_E.size());
for (Value val : list_U) {
U_vec.push_back(val->data);
}
for (Value val : list_V) {
V_vec.push_back(val->data);
}
for (Value val : list_E) {
E_vec.push_back(val->data);
}
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t src_id = pair.first;
const dgl_id_t dst_id = pair.second;
NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
CheckCtx(graph->Context(), {U, E, V_vec[dst_id], ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous({U, E, V_vec[dst_id], ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
}
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, V_vec, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
@@ -232,6 +364,7 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
auto pair = graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
{lhs_target, rhs_target, 1},
@@ -240,6 +373,36 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_lhs = args[2];
List<Value> list_rhs = args[3];
List<Value> list_out = args[4];
int lhs_target = args[5];
int rhs_target = args[6];
std::vector<NDArray> vec_lhs;
std::vector<NDArray> vec_rhs;
std::vector<NDArray> vec_out;
vec_lhs.reserve(list_lhs.size());
vec_rhs.reserve(list_rhs.size());
vec_out.reserve(list_out.size());
for (Value val : list_lhs) {
vec_lhs.push_back(val->data);
}
for (Value val : list_rhs) {
vec_rhs.push_back(val->data);
}
for (Value val : list_out) {
vec_out.push_back(val->data);
}
SDDMMHetero(op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string op = args[0];
......
@@ -29,6 +29,20 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
NDArray out,
std::vector<NDArray> out_aux);
/*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
with heterograph support.
*/
template <int XPU, typename IdType, int bits>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat,
std::vector<NDArray> out,
const std::vector<NDArray>& out_aux,
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
*/
@@ -53,6 +67,21 @@ void SDDMMCsr(const std::string& op,
NDArray out,
int lhs_target,
int rhs_target);
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr
format with heterograph support.
*/
template <int XPU, typename IdType, int bits>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
......