/**
 *  Copyright (c) 2020 by Contributors
 * @file array/kernel.cc
 * @brief New kernels
 */
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>

#ifdef USE_TVM
#include <dgl/runtime/dlpack_convert.h>
#include <featgraph.h>
#endif  // USE_TVM

#include "../c_api_common.h"
#include "./check.h"
#include "kernel_decl.h"

using namespace dgl::runtime;

namespace dgl {
namespace aten {
namespace {}  // namespace

/** @brief Generalized Sparse Matrix-Matrix Multiplication. */
void SpMM(
    const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
    NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
  // TODO(zihao): format tuning
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        if (format == SparseFormat::kCSC) {
          SpMMCsr<XPU, IdType, Dtype>(
              op, reduce, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out,
              out_aux);
        } else if (format == SparseFormat::kCOO) {
          SpMMCoo<XPU, IdType, Dtype>(
              op, reduce, bcast, graph->GetCOOMatrix(0), ufeat, efeat, out,
              out_aux);
        } else {
          LOG(FATAL) << "SpMM only supports CSC and COO formats";
        }
      });
    });
  });
}

/** @brief Generalized segmented dense Matrix-Matrix Multiplication. */
void SegmentMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
    bool A_trans, bool B_trans) {
  CHECK_EQ(A->ndim, 2)
      << "segment_mm expects a 2D tensor for the first input.";
  CHECK_EQ(B->ndim, 3)
      << "segment_mm expects a 3D tensor for the second input.";
  CHECK(!A_trans);
  if (B_trans) {
    CHECK_EQ(A->shape[1], B->shape[2])
        << "segment_mm expects A.shape[1] == B.shape[2] when B_trans=True";
  } else {
    CHECK_EQ(A->shape[1], B->shape[1])
        << "segment_mm expects A.shape[1] == B.shape[1]";
  }
  CHECK_EQ(B->shape[0], seglen_A.NumElements())
      << "segment_mm expects len(seglen_A) == B.shape[0]";
  CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)
      << "segment_mm expects seglen_A to be on CPU.";
  CHECK(A->ctx == B->ctx)
      << "segment_mm expects A and B to be on the same device.";
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
    ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
        SegmentMM<XPU, IdType, Dtype>(A, B, C, seglen_A, A_trans, B_trans);
      });
    });
  });
}

void SegmentMMBackwardB(
    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
  CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
                          "for the first input.";
  CHECK_EQ(dC->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
                           "for the second input.";
  CHECK_EQ(seglen->ctx.device_type, kDGLCPU)
      << "segment_mm expects seglen to be on CPU.";
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
    ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
        SegmentMMBackwardB<XPU, IdType, Dtype>(A, dC, dB, seglen);
      });
    });
  });
}
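// Illustrative sketch (not part of the library): the nested dispatch macros
// above pick a concrete kernel instantiation at runtime. Assuming a CUDA
// graph with int32 node IDs and float32 features, the SpMM dispatch above
// effectively resolves to a call like
//
//   SpMMCsr<kDGLCUDA, int32_t, float>(
//       op, reduce, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out,
//       out_aux);
//
// ATEN_XPU_SWITCH_CUDA binds XPU, ATEN_ID_TYPE_SWITCH binds IdType, and
// ATEN_FLOAT_TYPE_SWITCH_16BITS binds Dtype (including 16-bit float types).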
/** @brief Generalized Dense Matrix-Matrix Multiplication according to
 * relation types.
 */
void GatherMM(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b) {
  CHECK_EQ(A->ndim, 2)
      << "gather_mm operator expects a 2D tensor for the first input.";
  CHECK_EQ(B->ndim, 3)
      << "gather_mm operator expects a 3D tensor for the second input.";
  CHECK(A->ctx == B->ctx)
      << "gather_mm expects all arguments to be on the same device.";
  if (aten::IsNullArray(idx_a)) {
    CHECK_EQ(A->shape[0], idx_b->shape[0])
        << "gather_mm expects len(idx_b) == A.shape[0] when idx_a is None.";
    CHECK(A->ctx == idx_b->ctx)
        << "gather_mm expects all arguments to be on the same device.";
  } else if (aten::IsNullArray(idx_b)) {
    CHECK_EQ(B->shape[0], idx_a->shape[0])
        << "gather_mm expects len(idx_a) == B.shape[0] when idx_b is None.";
    CHECK(A->ctx == idx_a->ctx)
        << "gather_mm expects all arguments to be on the same device.";
  } else {
    CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
        << "gather_mm expects len(idx_a) == len(idx_b) when both idx_a and "
           "idx_b are given.";
    CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
        << "gather_mm expects all arguments to be on the same device.";
  }
  const auto idtype = aten::IsNullArray(idx_a) ? idx_b->dtype : idx_a->dtype;
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
    ATEN_ID_TYPE_SWITCH(idtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
        GatherMM<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b);
      });
    });
  });
}

/** @brief Generalized Dense Matrix-Matrix Multiplication according to
 * relation types, with scatter of the results.
 */
void GatherMMScatter(
    const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
    const NDArray idx_b, const NDArray idx_c) {
  CHECK_EQ(A->ndim, 2)
      << "gather_mm_scatter expects a 2D tensor for the first input.";
  CHECK(A->ctx == B->ctx)
      << "gather_mm_scatter expects all arguments to be on the same device.";
  if (!aten::IsNullArray(idx_c))
    CHECK(A->ctx == idx_c->ctx)
        << "gather_mm_scatter expects all arguments to be on the same device.";
  if (aten::IsNullArray(idx_a) && !aten::IsNullArray(idx_b)) {
    CHECK_EQ(A->shape[0], idx_b->shape[0])
        << "gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is "
           "None.";
    CHECK(A->ctx == idx_b->ctx)
        << "gather_mm_scatter expects all arguments to be on the same device.";
  } else if (aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
    CHECK_EQ(B->shape[0], idx_a->shape[0])
        << "gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is "
           "None.";
    CHECK(A->ctx == idx_a->ctx)
        << "gather_mm_scatter expects all arguments to be on the same device.";
  } else if (!aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
    CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
        << "gather_mm_scatter expects len(idx_a) == len(idx_b) "
        << "when both idx_a and idx_b are given.";
    CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
        << "gather_mm_scatter expects all arguments to be on the same device.";
  }
  ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
    ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
        GatherMMScatter<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b, idx_c);
      });
    });
  });
}
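// Illustrative sketch (assumption, not library code): semantically, with both
// index arrays given, GatherMM computes per-row products
//
//   for (int64_t i = 0; i < idx_a->shape[0]; ++i)
//     C[i] = A[idx_a[i]] * B[idx_b[i]];  // row vector times matrix
//
// where a null idx_a (or idx_b) means "use i directly" on that operand, as
// the shape checks above imply. GatherMMScatter additionally scatters each
// product row to C[idx_c[i]] instead of writing C in order.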
/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph
 * support.
 */
void SpMMHetero(
    const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
    const std::vector<NDArray>& ufeat_vec,
    const std::vector<NDArray>& efeat_vec, std::vector<NDArray>* out,
    std::vector<std::vector<NDArray>>* out_aux) {
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);

  std::vector<CSRMatrix> vec_graph;
  std::vector<dgl_type_t> ufeat_eid;
  std::vector<dgl_type_t> efeat_eid;
  std::vector<dgl_type_t> out_eid;
  auto pair = graph->meta_graph()->FindEdge(0);  // first etype
  NDArray ufeat_etype0 =
      (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
  NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];
  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    vec_graph.push_back(graph->GetCSCMatrix(etype));
    auto pair = graph->meta_graph()->FindEdge(etype);
    ufeat_eid.push_back(pair.first);
    efeat_eid.push_back(etype);
    out_eid.push_back(pair.second);
    if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])
      LOG(FATAL) << "Column width of the input node features of all etypes "
                    "must be the same.";
    if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])
      LOG(FATAL) << "Column width of the input edge features of all etypes "
                    "must be the same.";
  }
  const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);

  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(
          (*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
            if (format == SparseFormat::kCSC) {
              SpMMCsrHetero<XPU, IdType, Dtype>(
                  op, reduce, bcast, vec_graph, ufeat_vec, efeat_vec, out,
                  out_aux, ufeat_eid, out_eid);
            } else {
              // TODO(Israt): Add support for COO format
              LOG(FATAL)
                  << "SpMM only supports CSC format for graphs with number "
                  << "of relation types > 1";
            }
          });
    });
  });
}

/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMM(
    const std::string& op, HeteroGraphPtr graph, NDArray lhs, NDArray rhs,
    NDArray out, int lhs_target, int rhs_target) {
  // TODO(zihao): format tuning
  SparseFormat format = graph->SelectFormat(0, COO_CODE);
  const auto& bcast = CalcBcastOff(op, lhs, rhs);

  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
        if (format == SparseFormat::kCSR) {
          SDDMMCsr<XPU, IdType, Dtype>(
              op, bcast, graph->GetCSRMatrix(0), lhs, rhs, out, lhs_target,
              rhs_target);
        } else if (format == SparseFormat::kCOO) {
          SDDMMCoo<XPU, IdType, Dtype>(
              op, bcast, graph->GetCOOMatrix(0), lhs, rhs, out, lhs_target,
              rhs_target);
        } else {
          LOG(FATAL) << "SDDMM only supports CSR and COO formats";
        }
      });
    });
  });
}

/**
 * @brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.
 *
 * @param graph The input graph.
 * @param target 'u', 'v' or 'e'. The target of the lhs or rhs data of an
 *        etype.
 * @param etype Relation type of the input graph.
 */
int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
  auto pair = graph->meta_graph()->FindEdge(etype);
  if (target == 0) return pair.first;
  if (target == 2) return pair.second;
  return etype;
}
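// Illustrative note (inferred from get_typeid_by_target above): the integer
// targets follow the convention 0 = 'u' (source node), 1 = 'e' (edge),
// 2 = 'v' (destination node). Per edge (u, e, v), SDDMM then computes
// roughly
//
//   out[e] = op(lhs[id_of(lhs_target)], rhs[id_of(rhs_target)]);
//
// e.g. with lhs_target = 0 and rhs_target = 2, out[e] = op(lhs[u], rhs[v]),
// the classic "combine endpoint features per edge" pattern.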
/** @brief Generalized Sampled Dense-Dense Matrix Multiplication with
 * hetero-graph support.
 */
void SDDMMHetero(
    const std::string& op, HeteroGraphPtr graph, std::vector<NDArray> lhs,
    std::vector<NDArray> rhs, std::vector<NDArray> out, int lhs_target,
    int rhs_target) {
  SparseFormat format = graph->SelectFormat(0, COO_CODE);

  std::vector<dgl_type_t> lhs_eid;
  std::vector<dgl_type_t> rhs_eid;
  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
    rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
  }
  const auto& bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
  ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(
          out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
            if (format == SparseFormat::kCSR) {
              std::vector<CSRMatrix> vec_csr;
              for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
                   ++etype) {
                vec_csr.push_back(graph->GetCSRMatrix(etype));
              }
              SDDMMCsrHetero<XPU, IdType, Dtype>(
                  op, bcast, vec_csr, lhs, rhs, out, lhs_target, rhs_target,
                  lhs_eid, rhs_eid);
            } else if (format == SparseFormat::kCOO) {
              std::vector<COOMatrix> vec_coo;
              for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
                   ++etype) {
                vec_coo.push_back(graph->GetCOOMatrix(etype));
              }
              SDDMMCooHetero<XPU, IdType, Dtype>(
                  op, bcast, vec_coo, lhs, rhs, out, lhs_target, rhs_target,
                  lhs_eid, rhs_eid);
            } else {
              LOG(FATAL) << "SDDMM only supports CSR and COO formats";
            }
          });
    });
  });
}

/** @brief Generalized Edge_softmax op for forward */
void Edge_softmax_forward(
    const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
    NDArray out) {
  // TODO(zhejiang): add gpu op for edge_softmax
  const auto& bcast = CalcBcastOff(op, ufeat, efeat);

  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(
          out->dtype, Dtype, XPU, "edge_softmax out data", {
            Edge_softmax_csr_forward<XPU, IdType, Dtype>(
                op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
          });
    });
  });
}

/** @brief Generalized Edge_softmax op for backward */
void Edge_softmax_backward(
    const std::string& op, HeteroGraphPtr graph, NDArray out, NDArray sds,
    NDArray back_out, NDArray ufeat) {
  // TODO(zhejiang): add gpu op for edge_softmax
  const auto& bcast = CalcBcastOff(op, ufeat, sds);

  ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
    ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(
          out->dtype, Dtype, XPU, "edge_softmax out data_back", {
            Edge_softmax_csr_backward<XPU, IdType, Dtype>(
                op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
          });
    });
  });
}

NDArray GetEdgeMapping(HeteroGraphRef graph) {
  SparseFormat format = graph->SelectFormat(0, CSC_CODE);
  if (format == SparseFormat::kCSC) {
    return graph.sptr()->GetCSCMatrix(0).data;
  } else {
    return NullArray();
  }
}

/** @brief Segment reduce dispatch function. */
void SegmentReduceDispatch(
    const std::string& op, NDArray feat, NDArray offsets, NDArray out,
    NDArray arg) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
    ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
        SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
      });
    });
  });
}
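// Illustrative sketch (assumption, not library code): SegmentReduce treats
// `offsets` as CSR-style segment boundaries over the first dimension of
// `feat`, i.e. roughly
//
//   for (int64_t i = 0; i + 1 < offsets->shape[0]; ++i)
//     out[i] = reduce(feat[offsets[i]], ..., feat[offsets[i + 1] - 1]);
//
// with `arg` recording the winning indices for min/max reductions so the
// backward pass (BackwardSegmentCmpDispatch below) can route gradients.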
/** @brief Scatter Add (on first dimension) dispatch function. */
void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "ScatterAdd", {
    ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
        ScatterAdd<XPU, IdType, Dtype>(feat, idx, out);
      });
    });
  });
}

/** @brief Update gradients (reduce op max/min) dispatch function on
 * heterogeneous graph.
 */
void UpdateGradMinMaxDispatchHetero(
    const HeteroGraphPtr& graph, const std::string& op,
    const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
    const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
  auto pair = graph->meta_graph()->FindEdge(0);  // checking the first etype
  auto src_id = pair.first;
  ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
    ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(
          feat[src_id]->dtype, Dtype, XPU, "Feature data", {
            UpdateGradMinMax_hetero<XPU, IdType, Dtype>(
                graph, op, feat, idx, idx_etype, out);
          });
    });
  });
}

/** @brief Backward segment cmp dispatch function. */
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
  ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
    ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
        BackwardSegmentCmp<XPU, IdType, Dtype>(feat, arg, out);
      });
    });
  });
}

std::pair<CSRMatrix, NDArray> CSRMM(
    CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights) {
  CHECK_EQ(A.num_cols, B.num_rows)
      << "The number of nodes of the destination node type of the first "
         "graph must be the same as the number of nodes of the source node "
         "type of the second graph.";
  CheckCtx(
      A.indptr->ctx, {A_weights, B_weights},
      {"A's edge weights", "B's edge weights"});
  CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
  CHECK_EQ(A.indptr->dtype, B.indptr->dtype)
      << "ID types of two graphs must match.";
  CHECK_EQ(A_weights->dtype, B_weights->dtype)
      << "Data types of two edge weights must match.";

  std::pair<CSRMatrix, NDArray> ret;
  ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, "CSRMM", {
    ATEN_ID_TYPE_SWITCH(A.indptr->dtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
        ret = CSRMM<XPU, IdType, DType>(A, A_weights, B, B_weights);
      });
    });
  });
  return ret;
}

std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {
  CHECK(A.size() > 0) << "The list of graphs must not be empty.";
  CHECK_EQ(A.size(), A_weights.size())
      << "The list of edge weights must have the same length as the list of "
         "graphs.";
  const auto ctx = A[0].indptr->ctx;
  const auto idtype = A[0].indptr->dtype;
  const auto dtype = A_weights[0]->dtype;
  const auto num_rows = A[0].num_rows;
  const auto num_cols = A[0].num_cols;
  for (size_t i = 0; i < A.size(); ++i) {
    CHECK_EQ(A[i].indptr->ctx, ctx)
        << "The devices of all graphs must be equal.";
    CHECK_EQ(A[i].indptr->dtype, idtype)
        << "The ID types of all graphs must be equal.";
    CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0])
        << "Shape of edge weights does not match the number of edges.";
    CHECK_EQ(A_weights[i]->ctx, ctx)
        << "The devices of edge weights must be the same as that of the "
           "graphs.";
    CHECK_EQ(A_weights[i]->dtype, dtype)
        << "The data types of all edge weights must be equal.";
    CHECK_EQ(A[i].num_rows, num_rows)
        << "Graphs must have the same number of nodes.";
    CHECK_EQ(A[i].num_cols, num_cols)
        << "Graphs must have the same number of nodes.";
  }

  std::pair<CSRMatrix, NDArray> ret;
  ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "CSRSum", {
    ATEN_ID_TYPE_SWITCH(idtype, IdType, {
      ATEN_FLOAT_TYPE_SWITCH(dtype, DType, "Edge weights", {
        ret = CSRSum<XPU, IdType, DType>(A, A_weights);
      });
    });
  });
  return ret;
}
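// Illustrative note (assumption, not library code): CSRMM and CSRSum operate
// on weighted adjacency matrices. If A and B carry edge weights forming
// sparse matrices A_w and B_w, CSRMM returns the structure and weights of
// the product A_w * B_w, and CSRSum those of the sum over graphs that share
// the same node sets. For example:
//
//   std::pair<CSRMatrix, NDArray> prod = CSRMM(A_csr, A_w, B_csr, B_w);
//   // prod.first  : CSR structure of the product graph
//   // prod.second : per-edge weights of the product graph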
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      const std::string reduce_op = args[2];
      NDArray U = args[3];
      NDArray E = args[4];
      NDArray V = args[5];
      NDArray ArgU = args[6];
      NDArray ArgE = args[7];
      CheckCtx(
          graph->Context(), {U, E, V, ArgU, ArgE},
          {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
      CheckContiguous(
          {U, E, V, ArgU, ArgE},
          {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
      CHECK_EQ(graph->NumEdgeTypes(), 1);
      auto pair =
          graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      CheckShape(
          {graph->NumVertices(src_vtype), graph->NumEdges(0),
           graph->NumVertices(dst_vtype)},
          {0, 1, 2, 2, 2}, {U, E, V, ArgU, ArgE},
          {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
      SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray A = args[0];
      NDArray B = args[1];
      NDArray C = args[2];
      NDArray idx_a = args[3];
      NDArray idx_b = args[4];
      GatherMM(A, B, C, idx_a, idx_b);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray A = args[0];
      NDArray B = args[1];
      NDArray C = args[2];
      NDArray idx_a = args[3];
      NDArray idx_b = args[4];
      NDArray idx_c = args[5];
      GatherMMScatter(A, B, C, idx_a, idx_b, idx_c);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray A = args[0];
      NDArray B = args[1];
      NDArray C = args[2];
      NDArray seglen_A = args[3];
      bool A_trans = args[4];
      bool B_trans = args[5];
      SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMMBackwardB")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray A = args[0];
      NDArray dC = args[1];
      NDArray dB = args[2];
      NDArray seglen = args[3];
      SegmentMMBackwardB(A, dC, dB, seglen);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      NDArray U = args[2];
      NDArray E = args[3];
      NDArray V = args[4];
      Edge_softmax_forward(op, graph.sptr(), U, E, V);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_backward")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      NDArray out = args[2];
      NDArray sds = args[3];
      NDArray back_out = args[4];
      NDArray ufeat = args[5];
      Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);
    });
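// Illustrative sketch (assumption, not part of this file): entries registered
// through DGL_REGISTER_GLOBAL become named PackedFuncs that the language
// bindings look up at runtime, e.g.
//
//   const runtime::PackedFunc* pf =
//       runtime::Registry::Get("sparse._CAPI_DGLKernelSpMM");
//   (*pf)(graph_ref, "mul", "sum", U, E, V, ArgU, ArgE);
//
// Each positional argument is then unpacked from DGLArgs inside set_body.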
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      const std::string reduce_op = args[2];
      List<Value> list_U = args[3];
      List<Value> list_E = args[4];
      List<Value> list_V = args[5];
      List<Value> list_ArgU = args[6];
      List<Value> list_ArgE = args[7];
      List<Value> list_ArgU_ntype = args[8];
      List<Value> list_ArgE_etype = args[9];
      std::vector<std::vector<NDArray>> Arg_vec;  // ArgU + ArgE
      for (int i = 0; i < 4; ++i) {  // ArgU + ArgE + ArgU_ntype + ArgE_etype
        Arg_vec.push_back(std::vector<NDArray>());
      }
      std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
      std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
      std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
      Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
      Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
      Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
      Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
      for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
        auto pair = graph->meta_graph()->FindEdge(etype);
        const dgl_id_t src_id = pair.first;
        const dgl_id_t dst_id = pair.second;
        NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
        NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
        CheckCtx(
            graph->Context(),
            {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
            {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
        CheckContiguous(
            {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
            {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
      }
      SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      NDArray lhs = args[2];
      NDArray rhs = args[3];
      NDArray out = args[4];
      int lhs_target = args[5];
      int rhs_target = args[6];
      CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
      CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
      CHECK_EQ(graph->NumEdgeTypes(), 1);
      auto pair =
          graph->meta_graph()->FindEdge(0);  // only one etype in the graph.
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      CheckShape(
          {graph->NumVertices(src_vtype), graph->NumEdges(0),
           graph->NumVertices(dst_vtype)},
          {lhs_target, rhs_target, 1}, {lhs, rhs, out},
          {"U_data", "E_data", "V_data"});
      SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMMHetero")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      List<Value> list_lhs = args[2];
      List<Value> list_rhs = args[3];
      List<Value> list_out = args[4];
      int lhs_target = args[5];
      int rhs_target = args[6];
      std::vector<NDArray> vec_lhs;
      std::vector<NDArray> vec_rhs;
      std::vector<NDArray> vec_out;
      vec_lhs.reserve(list_lhs.size());
      vec_rhs.reserve(list_rhs.size());
      vec_out.reserve(list_out.size());
      for (Value val : list_lhs) {
        vec_lhs.push_back(val->data);
      }
      for (Value val : list_rhs) {
        vec_rhs.push_back(val->data);
      }
      for (Value val : list_out) {
        vec_out.push_back(val->data);
      }
      SDDMMHetero(
          op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target,
          rhs_target);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string op = args[0];
      NDArray feat = args[1];
      NDArray offsets = args[2];
      NDArray out = args[3];
      NDArray arg = args[4];
      CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"});
      CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"});
      SegmentReduceDispatch(op, feat, offsets, out, arg);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray feat = args[0];
      NDArray idx = args[1];
      NDArray out = args[2];
      CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
      CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
      ScatterAddDispatch(feat, idx, out);
    });
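// Illustrative sketch (assumption, not library code): ScatterAdd accumulates
// rows of `feat` into `out` along the first dimension,
//
//   for (int64_t i = 0; i < feat->shape[0]; ++i)
//     out[idx[i]] += feat[i];  // indices may collide, hence "add"
//
// which is the usual backward of a gather along dimension 0.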
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      const std::string op = args[1];
      List<Value> list_feat = args[2];
      List<Value> list_idx = args[3];
      List<Value> list_idx_etype = args[4];
      List<Value> list_out = args[5];
      std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
      std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
      std::vector<NDArray> vec_idx_etype =
          ListValueToVector<NDArray>(list_idx_etype);
      std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
      // CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
      // CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
      UpdateGradMinMaxDispatchHetero(
          graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray feat = args[0];
      NDArray arg = args[1];
      NDArray out = args[2];
      CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"});
      CheckContiguous({feat, arg, out}, {"feat", "arg", "out"});
      BackwardSegmentCmpDispatch(feat, arg, out);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      *rv = GetEdgeMapping(graph);
    });

/**
 * @brief Sparse matrix multiplication with graph interface.
 *
 * @param A_ref The left operand.
 * @param A_weights The edge weights of graph A.
 * @param B_ref The right operand.
 * @param B_weights The edge weights of graph B.
 * @param num_vtypes The number of vertex types of the graph to be returned.
 * @return A pair consisting of the new graph as well as its edge weights.
 */
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const HeteroGraphRef A_ref = args[0];
      NDArray A_weights = args[1];
      const HeteroGraphRef B_ref = args[2];
      NDArray B_weights = args[3];
      int num_vtypes = args[4];

      const HeteroGraphPtr A = A_ref.sptr();
      const HeteroGraphPtr B = B_ref.sptr();
      CHECK_EQ(A->NumEdgeTypes(), 1)
          << "The first graph must have only one edge type.";
      CHECK_EQ(B->NumEdgeTypes(), 1)
          << "The second graph must have only one edge type.";
      const auto A_csr = A->GetCSRMatrix(0);
      const auto B_csr = B->GetCSRMatrix(0);
      auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);

      List<ObjectRef> ret;
      ret.push_back(
          HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
      ret.push_back(Value(MakeValue(result.second)));
      *rv = ret;
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      List<HeteroGraphRef> A_refs = args[0];
      List<Value> A_weights = args[1];

      std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
      std::vector<CSRMatrix> mats;
      mats.reserve(A_refs.size());
      int num_vtypes = 0;
      for (auto A_ref : A_refs) {
        const HeteroGraphPtr A = A_ref.sptr();
        CHECK_EQ(A->NumEdgeTypes(), 1)
            << "Graphs must have only one edge type.";
        mats.push_back(A->GetCSRMatrix(0));
        if (num_vtypes == 0) num_vtypes = A->NumVertexTypes();
      }
      auto result = CSRSum(mats, weights);

      List<ObjectRef> ret;
      ret.push_back(
          HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
      ret.push_back(Value(MakeValue(result.second)));
      *rv = ret;
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const HeteroGraphRef A_ref = args[0];
      NDArray A_weights = args[1];
      const HeteroGraphRef B_ref = args[2];

      const HeteroGraphPtr A = A_ref.sptr();
      const HeteroGraphPtr B = B_ref.sptr();
      CHECK_EQ(A->NumEdgeTypes(), 1)
          << "Both graphs must have only one edge type.";
      CHECK_EQ(B->NumEdgeTypes(), 1)
          << "Both graphs must have only one edge type.";
      const CSRMatrix& A_csr = A->GetCSRMatrix(0);
      const COOMatrix& B_coo = B->GetCOOMatrix(0);
      CHECK_EQ(A_csr.num_rows, B_coo.num_rows)
          << "Both graphs must have the same number of nodes.";
      CHECK_EQ(A_csr.num_cols, B_coo.num_cols)
          << "Both graphs must have the same number of nodes.";

      NDArray result;
      ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
        result = aten::CSRGetData<DType>(
            A_csr, B_coo.row, B_coo.col, A_weights, 0.);
      });
      *rv = result;
    });
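// Illustrative note (inferred from the CSRGetData call above): CSRMask reads
// A's edge weights at the edge positions of B, i.e. for the k-th edge (r, c)
// of B,
//
//   result[k] = A has edge (r, c) ? A_weights[edge_id_in_A(r, c)] : 0.0;
//
// so B acts as a sparsity mask applied to A's weighted adjacency matrix.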
#ifdef USE_TVM
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_LoadModule")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string path = args[0];
      dgl::featgraph::LoadFeatGraphModule(path);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef graph = args[0];
      NDArray lhs = args[1];
      NDArray rhs = args[2];
      NDArray out = args[3];
      CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
      CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
      CHECK_EQ(graph->NumEdgeTypes(), 1);
      // auto pair = graph->meta_graph()->FindEdge(0);  // only one etype in
      // the graph.
      // const dgl_type_t src_vtype = pair.first;
      // const dgl_type_t dst_vtype = pair.second;
      // CheckShape(
      //     {graph->NumVertices(src_vtype), graph->NumEdges(0),
      //      graph->NumVertices(dst_vtype)},
      //     {lhs_target, rhs_target, 1}, {lhs, rhs, out},
      //     {"U_data", "E_data", "V_data"});
      COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
      dgl::featgraph::SDDMMTreeReduction(
          DLPackConvert::ToDLPack(coo.row), DLPackConvert::ToDLPack(coo.col),
          DLPackConvert::ToDLPack(lhs), DLPackConvert::ToDLPack(rhs),
          DLPackConvert::ToDLPack(out));
    });
#endif  // USE_TVM

}  // namespace aten
}  // namespace dgl