Unverified Commit 08b60eb1 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add SpMM and SDDMM on CSR and COO in dgl include headers (#5016)

parent 0038a29b
......@@ -431,13 +431,9 @@ COOMatrix COOReorder(
* value array.
*/
std::pair<COOMatrix, FloatArray> COOLaborSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = NullArray(),
int importance_sampling = 0,
IdArray random_seed = NullArray(),
IdArray NIDs = NullArray());
COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob = NullArray(), int importance_sampling = 0,
IdArray random_seed = NullArray(), IdArray NIDs = NullArray());
/**
* @brief Randomly select a fixed number of non-zero entries along each given
......@@ -785,6 +781,48 @@ COOMatrix COOSliceContiguousChunk(
*/
COOMatrix COOLineGraph(const COOMatrix& coo, bool backtracking);
/**
* @brief Generalized Sparse Matrix-Matrix Multiplication on COO.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `copy_u`, `copy_e'.
* @param op The reduce operator, could be `sum`, `min`, `max'.
* @param coo The COO we apply SpMM on.
* @param ufeat The source node feature.
* @param efeat The edge feature.
* @param out The output feature on destination nodes.
* @param out_aux A list of NDArray's that contains auxiliary information such
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void COOSpMM(
const std::string& op, const std::string& reduce, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/** @brief COOSpMM C interface without std::string. */
void COOSpMM(
const char* op, const char* reduce, const COOMatrix& coo, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on COO.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `dot`, `copy_u`, `copy_e'.
* @param coo The COO we apply SpMM on.
* @param ufeat The source node feature.
* @param vfeat The destination node feature.
* @param out The output feature on edge.
* @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
* @param rhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
*/
void COOSDDMM(
const std::string& op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target);
/** @brief COOSDDMM C interface without std::string. */
void COOSDDMM(
const char* op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target);
} // namespace aten
} // namespace dgl
......
......@@ -459,16 +459,13 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
* array.
*/
std::pair<COOMatrix, FloatArray> CSRLaborSampling(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = NullArray(),
int importance_sampling = 0,
IdArray random_seed = NullArray(),
IdArray NIDs = NullArray());
CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob = NullArray(), int importance_sampling = 0,
IdArray random_seed = NullArray(), IdArray NIDs = NullArray());
/*!
* @brief Randomly select a fixed number of non-zero entries along each given row independently.
* @brief Randomly select a fixed number of non-zero entries along each given
* row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
......@@ -895,6 +892,48 @@ CSRMatrix CSRSliceContiguousChunk(
const std::vector<uint64_t>& src_vertex_range,
const std::vector<uint64_t>& dst_vertex_range);
/**
* @brief Generalized Sparse Matrix-Matrix Multiplication on CSR.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `copy_u`, `copy_e'.
* @param op The reduce operator, could be `sum`, `min`, `max'.
* @param csr The CSR we apply SpMM on.
* @param ufeat The source node feature.
* @param efeat The edge feature.
* @param out The output feature on destination nodes.
* @param out_aux A list of NDArray's that contains auxiliary information such
* as the argmax on source nodes and edges for reduce operators such as
* `min` and `max`.
*/
void CSRSpMM(
const std::string& op, const std::string& reduce, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/** @brief CSRSpMM C interface without std::string. */
void CSRSpMM(
const char* op, const char* reduce, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on CSR.
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
* `dot`, `copy_u`, `copy_e'.
* @param csr The CSR we apply SpMM on.
* @param ufeat The source node feature.
* @param vfeat The destination node feature.
* @param out The output feature on edge.
* @param lhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
* @param rhs_target Type of `ufeat` (0: source, 1: edge, 2: destination).
*/
void CSRSDDMM(
const std::string& op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target);
/** @brief CSRSDDMM C interface without std::string. */
void CSRSDDMM(
const char* op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target);
} // namespace aten
} // namespace dgl
......
......@@ -4,6 +4,7 @@
* @brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
......@@ -15,6 +16,7 @@
#include "../c_api_common.h"
#include "./arith.h"
#include "./array_op.h"
#include "./kernel_decl.h"
using namespace dgl::runtime;
......@@ -545,9 +547,8 @@ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
int importance_sampling, IdArray random_seed, IdArray NIDs) {
std::pair<COOMatrix, FloatArray> ret;
ATEN_CSR_SWITCH_CUDA_UVA(mat, rows, XPU, IdType, "CSRLaborSampling", {
const auto dtype = IsNullArray(prob)
? DGLDataTypeTraits<float>::dtype
: prob->dtype;
const auto dtype =
IsNullArray(prob) ? DGLDataTypeTraits<float>::dtype : prob->dtype;
ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "probability", {
ret = impl::CSRLaborSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, importance_sampling, random_seed, NIDs);
......@@ -819,9 +820,8 @@ std::pair<COOMatrix, FloatArray> COOLaborSampling(
int importance_sampling, IdArray random_seed, IdArray NIDs) {
std::pair<COOMatrix, FloatArray> ret;
ATEN_COO_SWITCH(mat, XPU, IdType, "COOLaborSampling", {
const auto dtype = IsNullArray(prob)
? DGLDataTypeTraits<float>::dtype
: prob->dtype;
const auto dtype =
IsNullArray(prob) ? DGLDataTypeTraits<float>::dtype : prob->dtype;
ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "probability", {
ret = impl::COOLaborSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, importance_sampling, random_seed, NIDs);
......@@ -1088,6 +1088,93 @@ Frontiers DGLDFSLabeledEdges(
return ret;
}
void CSRSpMM(
const std::string& op, const std::string& reduce, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
SpMMCsr<XPU, IdType, Dtype>(
op, reduce, bcast, csr, ufeat, efeat, out, out_aux);
});
});
});
}
void CSRSpMM(
const char* op, const char* reduce, const CSRMatrix& csr, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
CSRSpMM(
std::string(op), std::string(reduce), csr, ufeat, efeat, out, out_aux);
}
void CSRSDDMM(
const std::string& op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target) {
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(csr.indptr->ctx.device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
SDDMMCsr<XPU, IdType, Dtype>(
op, bcast, csr, ufeat, efeat, out, lhs_target, rhs_target);
});
});
});
}
void CSRSDDMM(
const char* op, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target) {
return CSRSDDMM(
std::string(op), csr, ufeat, efeat, out, lhs_target, rhs_target);
}
void COOSpMM(
const std::string& op, const std::string& reduce, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
SpMMCoo<XPU, IdType, Dtype>(
op, reduce, bcast, coo, ufeat, efeat, out, out_aux);
});
});
});
}
void COOSpMM(
const char* op, const char* reduce, const COOMatrix& coo, NDArray ufeat,
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
COOSpMM(
std::string(op), std::string(reduce), coo, ufeat, efeat, out, out_aux);
}
void COOSDDMM(
const std::string& op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target) {
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH_CUDA(coo.row->ctx.device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
SDDMMCoo<XPU, IdType, Dtype>(
op, bcast, coo, ufeat, efeat, out, lhs_target, rhs_target);
});
});
});
}
void COOSDDMM(
const char* op, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, int lhs_target, int rhs_target) {
COOSDDMM(std::string(op), coo, ufeat, efeat, out, lhs_target, rhs_target);
}
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([](DGLArgs args, DGLRetValue* rv) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment