Unverified Commit f8d4264e authored by Minjie Wang, committed by GitHub

[Feature] Neighborhood-based sampling APIs (#1251)

* WIP: working on random choices

* light slice

* basic CPU impl

* add python binding; fix CreateFromCOO and CreateFromCSR returning unitgraph

* simple test case works

* fix bug in slicing probability array

* fix bug in getting the correct relation graph

* fix bug in creating placeholder graph

* enable omp

* add cpp test

* sample topk

* add in|out_subgraph

* try fix lint; passed all unittests

* fix lint

* fix msvc compile; add sorted flag and constructors

* fix msvc

* coosort

* COOSort; CSRRowWiseSampling; CSRRowWiseTopk

* WIP: remove DType in CSR and COO; Restrict data array to be IdArray

* fix all CSR ops for missing data array

* compiled

* passed tests

* lint

* test sampling out edge

* test different per-relation fanout/k values

* fix bug in random choice

* finished cpptest

* fix compile

* Add induced edges

* add check

* fixed bug in sampling on hypersparse graph; add tests

* add ascending flag

* in|out_subgraph returns subgraph and induced eid

* address comments

* lint

* fix
parent c7c0fd0e
...@@ -200,7 +200,8 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
/*!
* \brief Plain CSR matrix
*
* The column indices are 0-based and are not necessarily sorted. The data array stores
* integer ids for reading edge features.
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
...@@ -208,18 +209,28 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
*/
struct CSRMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief CSR index arrays */
IdArray indptr, indices;
/*! \brief data index array. When empty, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the column indices per row are sorted */
bool sorted = false;
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols,
IdArray parr, IdArray iarr, IdArray darr = IdArray(),
bool sorted_flag = false)
: num_rows(nrows), num_cols(ncols), indptr(parr), indices(iarr),
data(darr), sorted(sorted_flag) {}
};
/*!
* \brief Plain COO structure
*
* The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
...@@ -228,13 +239,23 @@ struct CSRMatrix {
*/
struct COOMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief COO index arrays */
IdArray row, col;
/*! \brief data index array. When empty, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the row indices are sorted */
bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */
bool col_sorted = false;
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols,
IdArray rarr, IdArray carr, IdArray darr = IdArray(),
bool rsorted = false, bool csorted = false)
: num_rows(nrows), num_cols(ncols), row(rarr), col(carr), data(darr),
row_sorted(rsorted), col_sorted(csorted) {}
};
///////////////////////// CSR routines //////////////////////////
...@@ -330,8 +351,106 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
/*! \return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr);
/*!
* \brief Sort the column index at each row in the ascending order.
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSort_(&csr)
*
* indptr = [0, 2, 3, 3, 5]
* indices = [0, 1, 1, 2, 3]
*/
void CSRSort_(CSRMatrix* csr);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = CSRRowWiseSampling(csr, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row, col and data indices.
*/
COOMatrix CSRRowWiseSampling(
CSRMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than k,
* all the values are picked.
*
* Examples:
*
* // csr.num_rows = 4;
* // csr.num_cols = 4;
* // csr.indptr = [0, 2, 3, 3, 5]
* // csr.indices = [0, 1, 1, 2, 3]
* // csr.data = [2, 3, 0, 1, 4]
* CSRMatrix csr = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = CSRRowWiseTopk(csr, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input CSR matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
* \param ascending If true, elements are sorted in ascending order, equivalent to finding
* the K smallest values. Otherwise, find the K largest values.
* \return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
*/
COOMatrix CSRRowWiseTopk(
CSRMatrix mat,
IdArray rows,
int64_t k,
FloatArray weight,
bool ascending = false);
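//
// Putting the example values from the two doc comments above into code: a minimal
// usage sketch, not part of this header. The array contents are the illustrative
// values used above, and VecToIdArray's default bit-width/context are assumed.
//
//   CSRMatrix csr(4, 4,
//                 VecToIdArray(std::vector<int64_t>({0, 2, 3, 3, 5})),   // indptr
//                 VecToIdArray(std::vector<int64_t>({0, 1, 1, 2, 3})),   // indices
//                 VecToIdArray(std::vector<int64_t>({2, 3, 0, 1, 4})));  // data
//   IdArray rows = VecToIdArray(std::vector<int64_t>({1, 3}));
//   // pick at most 2 neighbors per requested row, uniformly, without replacement
//   COOMatrix sampled = CSRRowWiseSampling(csr, rows, 2, FloatArray(), false);
//   // pick the single largest-weight entry per requested row
//   FloatArray weight = ...;  // one weight per non-zero entry
//   COOMatrix topk = CSRRowWiseTopk(csr, rows, 1, weight, false);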
///////////////////////// COO routines //////////////////////////
...@@ -404,6 +523,104 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
/*! \return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/*!
* \brief Sort the indices of a COO matrix.
*
* The function sorts row indices in ascending order. If sort_column is true,
* col indices are sorted in ascending order too. The data array of the returned COOMatrix
* stores the shuffled index which could be used to fetch edge data.
*
* \param mat The input coo matrix
* \param sort_column True if column index should be sorted too.
* \return COO matrix with index sorted.
*/
COOMatrix COOSort(COOMatrix mat, bool sort_column = false);
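//
// A small illustration of the shuffle index, following the style of the CSRSort_
// example above (values are illustrative only):
//
//   coo.row  = [1, 0, 1, 0]
//   coo.col  = [2, 2, 1, 1]
//   COOMatrix sorted = COOSort(coo, true);
//   // sorted.row  = [0, 0, 1, 1]
//   // sorted.col  = [1, 2, 1, 2]
//   // sorted.data = [3, 1, 2, 0]   // positions of the sorted entries in the input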
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
* The function performs random choices along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than num_samples,
* all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [1, 3]
* COOMatrix sampled = COORowWiseSampling(coo, rows, 2, FloatArray(), false);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [1, 3, 3]
* // sampled.cols = [1, 2, 3]
* // sampled.data = [3, 0, 4]
*
* \param mat Input coo matrix.
* \param rows Rows to sample from.
* \param num_samples Number of samples
* \param prob Unnormalized probability array. Should be of the same length as the data array.
* If an empty array is provided, assume uniform.
* \param replace True if sample with replacement
* \return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
*/
COOMatrix COORowWiseSampling(
COOMatrix mat,
IdArray rows,
int64_t num_samples,
FloatArray prob = FloatArray(),
bool replace = true);
/*!
* \brief Select K non-zero entries with the largest weights along each given row.
*
* The function performs top-k selection along each row independently.
* The picked indices are returned in the form of a COO matrix.
*
* If replace is false and a row has fewer non-zero values than k,
* all the values are picked.
*
* Examples:
*
* // coo.num_rows = 4;
* // coo.num_cols = 4;
* // coo.rows = [0, 0, 1, 3, 3]
* // coo.cols = [0, 1, 1, 2, 3]
* // coo.data = [2, 3, 0, 1, 4]
* COOMatrix coo = ...;
* IdArray rows = ... ; // [0, 1, 3]
* FloatArray weight = ... ; // [1., 0., -1., 10., 20.]
* COOMatrix sampled = COORowWiseTopk(coo, rows, 1, weight);
* // possible sampled coo matrix:
* // sampled.num_rows = 4
* // sampled.num_cols = 4
* // sampled.rows = [0, 1, 3]
* // sampled.cols = [1, 1, 2]
* // sampled.data = [3, 0, 1]
*
* \param mat Input COO matrix.
* \param rows Rows to sample from.
* \param k The K value.
* \param weight Weight associated with each entry. Should be of the same length as the
* data array. If an empty array is provided, assume uniform.
* \param ascending If true, elements are sorted in ascending order, equivalent to finding
* the K smallest values. Otherwise, find the K largest values.
* \return A COOMatrix storing the picked row and col indices. Its data field stores
* the index of the picked elements in the value array.
*/
COOMatrix COORowWiseTopk(
COOMatrix mat,
IdArray rows,
int64_t k,
FloatArray weight,
bool ascending = false);
// inline implementations
template <typename T>
...@@ -421,8 +638,6 @@ IdArray VecToIdArray(const std::vector<T>& vec,
return ret.CopyTo(ctx);
}
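// A construction sketch tying VecToIdArray to the CSRMatrix/COOMatrix constructors
// declared earlier (illustrative values only; VecToIdArray's default arguments assumed):
//
//   CSRMatrix csr(2, 3,
//                 VecToIdArray(std::vector<int64_t>({0, 1, 3})),   // indptr
//                 VecToIdArray(std::vector<int64_t>({2, 0, 1})));  // indices; data stays empty
//   COOMatrix coo(2, 3,
//                 VecToIdArray(std::vector<int64_t>({0, 1, 1})),   // row
//                 VecToIdArray(std::vector<int64_t>({2, 0, 1})));  // col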
///////////////////////// Dispatchers //////////////////////////
/*
...@@ -530,40 +745,18 @@ IdArray VecToIdArray(const std::vector<T>& vec,
} \
} while (0)
// Macro to dispatch according to device context and index type.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, ...) \
ATEN_XPU_SWITCH((csr).indptr->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH((csr).indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type.
#define ATEN_COO_SWITCH(coo, XPU, IdType, ...) \
ATEN_XPU_SWITCH((coo).row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH((coo).row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
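// Usage sketch, mirroring the call sites in array.cc further below: the macro binds
// XPU and IdType from the matrix's device context and index dtype and then runs the
// body for that combination.
//
//   bool ret = false;
//   ATEN_CSR_SWITCH(csr, XPU, IdType, {
//     ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
//   });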
...
...@@ -21,11 +21,29 @@ namespace dgl {
// Forward declaration
class BaseHeteroGraph;
typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;
struct FlattenedHeteroGraph;
typedef std::shared_ptr<FlattenedHeteroGraph> FlattenedHeteroGraphPtr;
struct HeteroSubgraph;
/*! \brief Enum class for edge direction */
enum class EdgeDir {
kIn, // in edge direction
kOut // out edge direction
};
/*!
* \brief Sparse graph format.
*/
enum class SparseFormat {
ANY = 0,
COO = 1,
CSR = 2,
CSC = 3
};
/*!
* \brief Base heterogenous graph.
*
...@@ -323,6 +341,8 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Get the adjacency matrix of the graph.
*
* TODO(minjie): deprecate this interface; replace it with GetXXXMatrix.
*
* By default, a row of returned adjacency matrix represents the destination
* of an edge and the column represents the source.
*
...@@ -339,6 +359,49 @@ class BaseHeteroGraph : public runtime::Object {
virtual std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const = 0;
/*!
* \brief Determine which format to use with a preference.
*
* Return the preferred format if the underlying relation graph supports it.
* Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments.
*
* \param etype Edge type.
* \param preferred_format Preferred sparse format.
* \return Available sparse format.
*/
virtual SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const = 0;
/*!
* \brief Get adjacency matrix in COO format.
* \param etype Edge type.
* \return COO matrix.
*/
virtual aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const = 0;
/*!
* \brief Get adjacency matrix in CSR format.
*
* The row and column sizes are equal to the number of dsttype and srctype
* nodes, respectively.
*
* \param etype Edge type.
* \return CSR matrix.
*/
virtual aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const = 0;
/*!
* \brief Get adjacency matrix in CSC format.
*
* A CSC matrix is equivalent to the transpose of a CSR matrix.
* We reuse the CSRMatrix data structure as return value. The row and column
* sizes are equal to the number of dsttype and srctype nodes, respectively.
*
* \param etype Edge type.
* \return A CSR matrix.
*/
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;
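// Usage sketch (hypothetical single-relation graph `hg`, edge type 0): ask for the
// preferred format first, then fetch the corresponding matrix.
//
//   SparseFormat fmt = hg->SelectFormat(0, SparseFormat::CSC);
//   if (fmt == SparseFormat::CSC) {
//     aten::CSRMatrix csc = hg->GetCSCMatrix(0);   // rows indexed by dst nodes
//   } else if (fmt == SparseFormat::COO) {
//     aten::COOMatrix coo = hg->GetCOOMatrix(0);
//   }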
/*!
* \brief Extract the induced subgraph by the given vertices.
*
...@@ -390,24 +453,43 @@ class BaseHeteroGraph : public runtime::Object {
// Define HeteroGraphRef
DGL_DEFINE_OBJECT_REF(HeteroGraphRef, BaseHeteroGraph);
/*!
* \brief Hetero-subgraph data structure.
*
* This class can be used as arguments and return values of a C API.
*
* <code>
* DGL_REGISTER_GLOBAL("some_c_api")
* .set_body([] (DGLArgs args, DGLRetValue* rv) {
* HeteroSubgraphRef subg = args[0];
* std::shared_ptr<HeteroSubgraph> ret = do_something( ... );
* *rv = HeteroSubgraphRef(ret);
* });
* </code>
*/
struct HeteroSubgraph : public runtime::Object {
/*! \brief The heterograph. */
HeteroGraphPtr graph;
/*!
* \brief The induced vertex ids of each entity type.
* The vector length is equal to the number of vertex types in the parent graph.
* Each array i has the same length as the number of vertices in type i.
* Empty array is allowed if the mapping is identity.
*/
std::vector<IdArray> induced_vertices;
/*!
* \brief The induced edge ids of each relation type.
* The vector length is equal to the number of edge types in the parent graph.
* Each array i has the same length as the number of edges in type i.
* Empty array is allowed if the mapping is identity.
*/
std::vector<IdArray> induced_edges;
static constexpr const char* _type_key = "graph.HeteroSubgraph";
DGL_DECLARE_OBJECT_TYPE_INFO(HeteroSubgraph, runtime::Object);
};
// Define HeteroSubgraphRef
DGL_DEFINE_OBJECT_REF(HeteroSubgraphRef, HeteroSubgraph);
/*! \brief The flattened heterograph */
struct FlattenedHeteroGraph : public runtime::Object {
...@@ -465,21 +547,8 @@ struct FlattenedHeteroGraph : public runtime::Object {
};
DGL_DEFINE_OBJECT_REF(FlattenedHeteroGraphRef, FlattenedHeteroGraph);
// creators
inline SparseFormat ParseSparseFormat(const std::string& name) {
if (name == "coo")
return SparseFormat::COO;
...@@ -495,6 +564,36 @@ inline SparseFormat ParseSparseFormat(const std::string& name) {
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs);
/*!
* \brief Create a heterograph from COO input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param num_src Number of nodes in the source type.
* \param num_dst Number of nodes in the destination type.
* \param row Src node ids of the edges.
* \param col Dst node ids of the edges.
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY);
/*!
* \brief Create a heterograph from CSR input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param num_src Number of nodes in the source type.
* \param num_dst Number of nodes in the destination type.
* \param indptr Indptr array
* \param indices Indices array
* \param edge_ids Edge ids
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::ANY);
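// Creation sketch (illustrative values only): a bipartite graph with 2 source nodes,
// 3 destination nodes and 3 edges, leaving the storage format up to DGL.
//
//   IdArray row = VecToIdArray(std::vector<int64_t>({0, 0, 1}));  // src ids
//   IdArray col = VecToIdArray(std::vector<int64_t>({1, 2, 2}));  // dst ids
//   HeteroGraphPtr g = CreateFromCOO(2, 2, 3, row, col, SparseFormat::ANY);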
/*!
* \brief Given a list of graphs, remove the common nodes that do not have inbound and
* outbound edges.
...@@ -508,6 +607,24 @@ HeteroGraphPtr CreateHeteroGraph(
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(const std::vector<HeteroGraphPtr> &graphs);
/*!
* \brief Extract the subgraph of the in edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \return Subgraph containing only the in edges. The returned graph has the same
* schema as the original one.
*/
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes);
/*!
* \brief Extract the subgraph of the out edges of the given nodes.
* \param graph Graph
* \param nodes Node IDs of each type
* \return Subgraph containing only the out edges. The returned graph has the same
* schema as the original one.
*/
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& nodes);
};  // namespace dgl
#endif  // DGL_BASE_HETEROGRAPH_H_
...@@ -240,11 +240,10 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; }
void SortCSR() override {
if (adj_.sorted)
return;
aten::CSRSort_(&adj_);
}
private:
...
...@@ -18,6 +18,7 @@ namespace dgl {
namespace {
// Get a unique integer ID representing this thread.
inline uint32_t GetThreadId() {
static int num_threads = 0;
static std::mutex mutex;
...@@ -92,18 +93,51 @@ class RandomEngine {
*/
template<typename T>
T Uniform(T lower, T upper) {
// Although the result is in [lower, upper), we allow lower == upper as in
// www.cplusplus.com/reference/random/uniform_real_distribution/uniform_real_distribution/
CHECK_LE(lower, upper);
std::uniform_real_distribution<T> dist(lower, upper);
return dist(rng_);
}
/*!
* \brief Pick a random integer between 0 to N-1 according to given probabilities
* \tparam IdxType Return integer type
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \return An integer randomly picked from 0 to N-1.
*/
template<typename IdxType>
IdxType Choice(FloatArray prob);
/*!
* \brief Pick random integers between 0 to N-1 according to given probabilities
*
* If replace is false, the number of picked integers must not be larger than N.
*
* \tparam IdxType Id type
* \tparam FloatType Probability value type
* \param num Number of integers to choose
* \param prob Array of N unnormalized probability of each element. Must be non-negative.
* \param replace If true, choose with replacement.
* \return Integer array
*/
template <typename IdxType, typename FloatType>
IdArray Choice(int64_t num, FloatArray prob, bool replace = true);
/*!
* \brief Pick random integers from population by uniform distribution.
*
* If replace is false, num must not be larger than population.
*
* \tparam IdxType Return integer type
* \param num Number of integers to choose
* \param population Total number of elements to choose from.
* \param replace If true, choose with replacement.
* \return Integer array
*/
template <typename IdxType>
IdArray UniformChoice(int64_t num, int64_t population, bool replace = true);
private:
std::default_random_engine rng_;
};
...
...@@ -674,6 +674,21 @@ class Map<std::string, V, T1, T2> : public ObjectRef {
}
};
/*!
* \brief Helper function to convert a List<Value> object to a vector.
* \tparam T element type
* \param list Input list object.
* \return std vector
*/
template <typename T>
inline std::vector<T> ListValueToVector(const List<Value>& list) {
std::vector<T> ret;
ret.reserve(list.size());
for (Value val : list)
ret.push_back(val->data);
return ret;
}
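// Usage sketch inside a registered C API body (the argument position and the API name
// are hypothetical; the pattern follows the HeteroSubgraphRef example shown earlier):
//
//   DGL_REGISTER_GLOBAL("some_c_api")
//   .set_body([] (DGLArgs args, DGLRetValue* rv) {
//     List<Value> fanout_list = args[2];
//     std::vector<int64_t> fanouts = ListValueToVector<int64_t>(fanout_list);
//     ...
//   });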
}  // namespace runtime
}  // namespace dgl
...
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/sampling/neighbor.h
* \brief Neighborhood-based sampling.
*/
#ifndef DGL_SAMPLING_NEIGHBOR_H_
#define DGL_SAMPLING_NEIGHBOR_H_
#include <dgl/base_heterograph.h>
#include <dgl/array.h>
#include <vector>
namespace dgl {
namespace sampling {
/*!
* \brief Sample from the neighbors of the given nodes and return the sampled edges as a graph.
*
* When sampling with replacement, the sampled subgraph could have parallel edges.
*
* For sampling without replacement, if fanout > the number of neighbors, all the
* neighbors will be sampled.
*
* \param hg The input graph.
* \param nodes Node IDs of each type. The vector length must be equal to the number
* of node types. Empty array is allowed.
* \param fanouts Number of sampled neighbors for each edge type. The vector length
* should be equal to the number of edge types, or one if they all
* have the same fanout.
* \param dir Edge direction.
* \param probability A vector of 1D float arrays, indicating the transition probability of
* each edge by edge type. An empty float array assumes uniform transition.
* \param replace If true, sample with replacement.
* \return Sampled neighborhoods as a graph. The return graph has the same schema as the
* original one.
*/
HeteroSubgraph SampleNeighbors(
const HeteroGraphPtr hg,
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<FloatArray>& probability,
bool replace = true);
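// Call sketch for a graph with a single node type and a single edge type
// (hypothetical values; an empty FloatArray means uniform transition):
//
//   HeteroGraphPtr hg = ...;
//   std::vector<IdArray> seed_nodes = {VecToIdArray(std::vector<int64_t>({0, 1}))};
//   std::vector<int64_t> fanouts = {2};
//   std::vector<FloatArray> prob = {FloatArray()};
//   HeteroSubgraph subg = SampleNeighbors(hg, seed_nodes, fanouts, EdgeDir::kIn, prob, false);
//   // subg.induced_edges[0] maps each sampled edge back to its id in hg.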
/*!
* Select the neighbors with k-largest weights on the connecting edges for each given node.
*
* If k > the number of neighbors, all the neighbors are sampled.
*
* \param hg The input graph.
* \param nodes Node IDs of each type. The vector length must be equal to the number
* of node types. Empty array is allowed.
* \param k The k value for each edge type. The vector length
* should be equal to the number of edge types, or one if they all
* have the same fanout.
* \param dir Edge direction.
* \param weight A vector of 1D float arrays, indicating the weights associated with
* each edge.
* \param ascending If true, elements are sorted in ascending order, equivalent to finding
* the K smallest values. Otherwise, find the K largest values.
* \return Sampled neighborhoods as a graph. The return graph has the same schema as the
* original one.
*/
HeteroSubgraph SampleNeighborsTopk(
const HeteroGraphPtr hg,
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& k,
EdgeDir dir,
const std::vector<FloatArray>& weight,
bool ascending = false);
} // namespace sampling
} // namespace dgl
#endif // DGL_SAMPLING_NEIGHBOR_H_
"""Classes for heterogeneous graphs.""" """Classes for heterogeneous graphs."""
#pylint: disable= too-many-lines
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
import networkx as nx import networkx as nx
...@@ -617,6 +618,7 @@ class DGLHeteroGraph(object): ...@@ -617,6 +618,7 @@ class DGLHeteroGraph(object):
"to get view of one relation type. Use : to slice multiple types (e.g. " +\ "to get view of one relation type. Use : to slice multiple types (e.g. " +\
"G['srctype', :, 'dsttype'])." "G['srctype', :, 'dsttype'])."
orig_key = key
if not isinstance(key, tuple): if not isinstance(key, tuple):
key = (SLICE_FULL, key, SLICE_FULL) key = (SLICE_FULL, key, SLICE_FULL)
...@@ -624,6 +626,10 @@ class DGLHeteroGraph(object): ...@@ -624,6 +626,10 @@ class DGLHeteroGraph(object):
raise DGLError(err_msg) raise DGLError(err_msg)
etypes = self._find_etypes(key) etypes = self._find_etypes(key)
if len(etypes) == 0:
raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
if len(etypes) == 1: if len(etypes) == 1:
# no ambiguity: return the unitgraph itself # no ambiguity: return the unitgraph itself
srctype, etype, dsttype = self._canonical_etypes[etypes[0]] srctype, etype, dsttype = self._canonical_etypes[etypes[0]]
......
"""Sampler modules.""" """Sampler modules."""
from .randomwalks import * from .randomwalks import *
from .neighbor import *
"""Neighbor sampling APIs"""
from .._ffi.function import _init_api
from .. import backend as F
from ..base import DGLError, EID
from ..heterograph import DGLHeteroGraph
from .. import ndarray as nd
from .. import utils
__all__ = ['sample_neighbors', 'sample_neighbors_topk']
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=True):
"""Sample from the neighbors of the given nodes and return the induced subgraph.
When sampling with replacement, the sampled subgraph could have parallel edges.
For sampling without replacement, if fanout > the number of neighbors, all the
neighbors are sampled.
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
Parameters
----------
g : DGLHeteroGraph
Full graph structure.
nodes : tensor or dict
Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
fanout : int or list[int]
The number of sampled neighbors for each node on each edge type. Provide a list
to specify different fanout values for each edge type.
edge_dir : str, optional
Edge direction ('in' or 'out'). If 'in', sample from in edges. Otherwise,
sample from out edges.
prob : str, optional
Feature name used as the probabilities associated with each neighbor of a node.
Its shape should be compatible with a scalar edge feature tensor.
replace : bool, optional
If True, sample with replacement.
Returns
-------
DGLHeteroGraph
A sampled subgraph containing only the sampled neighbor edges from
``nodes``. The sampled subgraph has the same metagraph as the original
one.
"""
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
if not isinstance(fanout, list):
fanout = [int(fanout)] * len(g.etypes)
if len(fanout) != len(g.etypes):
raise DGLError('Fan-out must be specified for each edge type '
'if a list is provided.')
if prob is None:
prob_arrays = [nd.array([], ctx=nd.cpu())] * len(g.etypes)
else:
prob_arrays = []
for etype in g.canonical_etypes:
if prob in g.edges[etype].data:
prob_arrays.append(F.zerocopy_to_dgl_ndarray(g.edges[etype].data[prob]))
else:
prob_arrays.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout,
edge_dir, prob_arrays, replace)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i].tousertensor()
return ret
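# Usage sketch (hypothetical graph; the `dgl.heterograph` construction is shown only
# for illustration and is not part of this module):
#
#   g = dgl.heterograph({
#       ('user', 'follows', 'user'): [(0, 1), (0, 2), (1, 2), (2, 0)]})
#   sub = dgl.sampling.sample_neighbors(g, [0, 2], 1, edge_dir='in', replace=False)
#   print(sub.edata[dgl.EID])   # original ids of the sampled in-edges of nodes 0 and 2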
def sample_neighbors_topk(g, nodes, k, weight, edge_dir='in', ascending=False):
"""Select the neighbors with k-largest weights on the connecting edges for each given node.
If k > the number of neighbors, all the neighbors are sampled.
Node/edge features are not preserved. The original IDs of
the sampled edges are stored as the `dgl.EID` feature in the returned graph.
Parameters
----------
g : DGLHeteroGraph
Full graph structure.
nodes : tensor or dict
Node ids to sample neighbors from. The allowed types
are dictionary of node types to node id tensors, or simply node id
tensor if the given graph g has only one type of nodes.
k : int
The K value.
weight : str
Feature name of the weights associated with each edge. Its shape should be
compatible with a scalar edge feature tensor.
edge_dir : str, optional
Edge direction ('in' or 'out'). If 'in', sample from in edges.
Otherwise, sample from out edges.
ascending : bool, optional
If true, elements are sorted in ascending order, equivalent to finding
the K smallest values. Otherwise, find the K largest values.
Returns
-------
DGLGraph
A sampled subgraph by top k criterion. The sampled subgraph has the same
metagraph as the original one.
"""
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
if not isinstance(k, list):
k = [int(k)] * len(g.etypes)
if len(k) != len(g.etypes):
raise DGLError('K value must be specified for each edge type '
'if a list is provided.')
weight_arrays = []
for etype in g.canonical_etypes:
if weight in g.edges[etype].data:
weight_arrays.append(F.zerocopy_to_dgl_ndarray(g.edges[etype].data[weight]))
else:
raise DGLError('Edge weights "{}" do not exist for relation graph "{}".'.format(
weight, etype))
subgidx = _CAPI_DGLSampleNeighborsTopk(
g._graph, nodes_all_types, k, edge_dir, weight_arrays, bool(ascending))
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i].tousertensor()
return ret
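# Usage sketch for the top-k variant (assumes the graph from the previous sketch plus
# an edge feature named 'weight' and the PyTorch backend; names are illustrative):
#
#   g.edata['weight'] = torch.tensor([0.3, 0.1, 0.7, 0.2])
#   sub = dgl.sampling.sample_neighbors_topk(g, [1, 2], 1, 'weight')
#   # every given node keeps only its largest-weight in-edge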
_init_api('dgl.sampling.neighbor', __name__)
...@@ -3,7 +3,10 @@
import numpy as np
from scipy import sparse
from ._ffi.function import _init_api
from .base import EID
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from . import ndarray as nd
from .subgraph import DGLSubGraph
from . import backend as F
from .graph_index import from_coo
...@@ -16,7 +19,7 @@ from . import utils
__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max', 'knn_graph', 'segmented_knn_graph', 'add_self_loop',
'remove_self_loop', 'metapath_reachable_graph', 'in_subgraph', 'out_subgraph']
def pairwise_squared_distance(x):
...@@ -562,4 +565,84 @@ def partition_graph_with_halo(g, node_part, num_hops):
subg_dict[i] = subg
return subg_dict
def in_subgraph(g, nodes):
"""Extract the subgraph containing only the in edges of the given nodes.
The subgraph keeps the same type schema and the cardinality of the original one.
Node/edge features are not preserved. The original IDs of
the extracted edges are stored as the `dgl.EID` feature in the returned graph.
Parameters
----------
g : DGLHeteroGraph
Full graph structure.
nodes : tensor or dict
Node ids whose in edges will be extracted. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
Returns
-------
DGLHeteroGraph
The subgraph.
"""
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLInSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i].tousertensor()
return ret
def out_subgraph(g, nodes):
"""Extract the subgraph containing only the out edges of the given nodes.
The subgraph keeps the same type schema and the cardinality of the original one.
Node/edge features are not preserved. The original IDs of
the extracted edges are stored as the `dgl.EID` feature in the returned graph.
Parameters
----------
g : DGLHeteroGraph
Full graph structure.
nodes : tensor or dict
Node ids whose out edges will be extracted. The allowed types
are dictionary of node types to node id tensors, or simply node id tensor if
the given graph g has only one type of nodes.
Returns
-------
DGLHeteroGraph
The subgraph.
"""
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
raise DGLError("Must specify node type when the graph is not homogeneous.")
nodes = {g.ntypes[0] : nodes}
nodes_all_types = []
for ntype in g.ntypes:
if ntype in nodes:
nodes_all_types.append(utils.toindex(nodes[ntype]).todgltensor())
else:
nodes_all_types.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLOutSubgraph(g._graph, nodes_all_types)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
for i, etype in enumerate(ret.canonical_etypes):
ret.edges[etype].data[EID] = induced_edges[i].tousertensor()
return ret
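# Usage sketch for the two helpers above (hypothetical homogeneous graph; assumes the
# transform module is re-exported at the package level like the other routines here):
#
#   g = dgl.graph([(0, 1), (1, 2), (2, 0)])
#   in_sg = dgl.in_subgraph(g, [2])     # keeps only the edge 1 -> 2
#   out_sg = dgl.out_subgraph(g, [2])   # keeps only the edge 2 -> 0
#   print(in_sg.edata[dgl.EID])         # ids of the retained edges in g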
_init_api("dgl.transform") _init_api("dgl.transform")
...@@ -274,7 +274,7 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
...@@ -282,7 +282,7 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
...@@ -290,7 +290,7 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
bool CSRHasDuplicate(CSRMatrix csr) {
bool ret = false;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
});
return ret;
...@@ -298,7 +298,7 @@ bool CSRHasDuplicate(CSRMatrix csr) {
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
int64_t ret = 0;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
...@@ -306,7 +306,7 @@ int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
...@@ -314,7 +314,7 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
});
return ret;
...@@ -322,24 +322,24 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowData<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetData<XPU, IdType>(csr, row, col);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetData<XPU, IdType>(csr, rows, cols);
});
return ret;
}
...@@ -347,16 +347,16 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetDataAndIndices<XPU, IdType>(csr, rows, cols);
});
return ret;
}
CSRMatrix CSRTranspose(CSRMatrix csr) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRTranspose<XPU, IdType>(csr);
});
return ret;
}
...@@ -381,39 +381,67 @@ COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRSliceRows<XPU, IdType>(csr, start, end);
});
return ret;
}
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRSliceRows<XPU, IdType>(csr, rows);
});
return ret;
}
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, {
ret = impl::CSRSliceMatrix<XPU, IdType>(csr, rows, cols);
});
return ret;
}
void CSRSort_(CSRMatrix* csr) {
ATEN_CSR_SWITCH(*csr, XPU, IdType, {
impl::CSRSort_<XPU, IdType>(csr);
});
}
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, {
if (!prob.defined() || prob->shape[0] == 0) {
ret = impl::CSRRowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::CSRRowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, replace);
});
}
});
return ret;
}
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
COOMatrix ret;
ATEN_CSR_SWITCH(mat, XPU, IdType, {
ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
ret = impl::CSRRowWiseTopk<XPU, IdType, FloatType>(
mat, rows, k, weight, ascending);
});
});
return ret;
}
///////////////////////// COO routines //////////////////////////
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
bool ret = false;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
});
return ret;
...@@ -421,7 +449,7 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOIsNonZero<XPU, IdType>(coo, row, col);
});
return ret;
...@@ -429,7 +457,7 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
bool COOHasDuplicate(COOMatrix coo) {
bool ret = false;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOHasDuplicate<XPU, IdType>(coo);
});
return ret;
...@@ -437,7 +465,7 @@ bool COOHasDuplicate(COOMatrix coo) {
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
int64_t ret = 0;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
});
return ret;
...@@ -445,7 +473,7 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOGetRowNNZ<XPU, IdType>(coo, row);
});
return ret;
...@@ -453,16 +481,16 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray row) {
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(COOMatrix coo, int64_t row) {
std::pair<NDArray, NDArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOGetRowDataAndIndices<XPU, IdType>(coo, row);
});
return ret;
}
NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
NDArray ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOGetData<XPU, IdType>(coo, row, col);
});
return ret;
}
...@@ -470,48 +498,84 @@ NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
std::vector<NDArray> COOGetDataAndIndices(
COOMatrix coo, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOGetDataAndIndices<XPU, IdType>(coo, rows, cols);
});
return ret;
}
COOMatrix COOTranspose(COOMatrix coo) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOTranspose<XPU, IdType>(coo);
});
return ret;
}
CSRMatrix COOToCSR(COOMatrix coo) {
CSRMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOToCSR<XPU, IdType>(coo);
});
return ret;
}
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOSliceRows<XPU, IdType>(coo, start, end);
});
return ret;
}
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOSliceRows<XPU, IdType>(coo, rows);
});
return ret;
}
COOMatrix COOSliceMatrix(COOMatrix coo, NDArray rows, NDArray cols) {
COOMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, {
ret = impl::COOSliceMatrix<XPU, IdType>(coo, rows, cols);
});
return ret;
}
COOMatrix COOSort(COOMatrix mat, bool sort_column) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
ret = impl::COOSort<XPU, IdType>(mat, sort_column);
});
return ret;
}
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
if (!prob.defined() || prob->shape[0] == 0) {
ret = impl::COORowWiseSamplingUniform<XPU, IdType>(mat, rows, num_samples, replace);
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
ret = impl::COORowWiseSampling<XPU, IdType, FloatType>(
mat, rows, num_samples, prob, replace);
});
}
});
return ret;
}
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
COOMatrix ret;
ATEN_COO_SWITCH(mat, XPU, IdType, {
ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
ret = impl::COORowWiseTopk<XPU, IdType, FloatType>(
mat, rows, k, weight, ascending);
});
});
return ret;
}
...
...@@ -71,20 +71,20 @@ runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO
...@@ -95,17 +95,33 @@ COOMatrix CSRToCOO(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
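// Note (editorial reading of the CPU implementation, not normative): prob and
// weight are expected to be per-edge arrays of length NNZ, indexed by the edge
// data id (or by position when the data array is absent); e.g. prob[eid] is the
// unnormalized sampling weight of edge eid.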
///////////////////////////////////////////////////////////////////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col); bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
...@@ -122,32 +138,48 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row); ...@@ -122,32 +138,48 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row); runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray> std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row); COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col); runtime::NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices( std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo); COOMatrix COOTranspose(COOMatrix coo);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo); CSRMatrix COOToCSR(COOMatrix coo);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end); COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows); COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
COOMatrix COOSort(COOMatrix mat, bool sort_column);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
template <DLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform(
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/coo_sort.cc
* \brief COO sorting
*/
#include <dgl/array.h>
#include <numeric>
#include <algorithm>
#include <vector>
namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
COOMatrix COOSort(COOMatrix coo, bool sort_column) {
const int64_t nnz = coo.row->shape[0];
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
// Argsort
IdArray new_row = IdArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);
IdArray new_col = IdArray::Empty({nnz}, coo.col->dtype, coo.col->ctx);
IdArray new_data = IdArray::Empty({nnz}, coo.row->dtype, coo.row->ctx);
IdType* new_row_data = static_cast<IdType*>(new_row->data);
IdType* new_col_data = static_cast<IdType*>(new_col->data);
IdType* new_data_data = static_cast<IdType*>(new_data->data);
std::iota(new_data_data, new_data_data + nnz, 0);
if (sort_column) {
std::sort(
new_data_data,
new_data_data + nnz,
[coo_row_data, coo_col_data](IdType a, IdType b) {
return (coo_row_data[a] != coo_row_data[b]) ?
(coo_row_data[a] < coo_row_data[b]) :
(coo_col_data[a] < coo_col_data[b]);
});
} else {
std::sort(
new_data_data,
new_data_data + nnz,
[coo_row_data](IdType a, IdType b) {
// std::sort requires a strict weak ordering, so compare with '<' (not '<=').
return coo_row_data[a] < coo_row_data[b];
});
}
// Reorder according to shuffle
for (IdType i = 0; i < nnz; ++i) {
new_row_data[i] = coo_row_data[new_data_data[i]];
new_col_data[i] = coo_col_data[new_data_data[i]];
}
return COOMatrix{
coo.num_rows, coo.num_cols, std::move(new_row), std::move(new_col),
std::move(new_data), true, sort_column};
}
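// Illustrative example (not part of the build): with row = [2, 0, 2, 1],
// col = [1, 3, 0, 2] and sort_column = true, the argsort permutation is
// [1, 3, 2, 0], so the result has row = [0, 1, 2, 2], col = [3, 2, 0, 1],
// data = [1, 3, 2, 0], with row_sorted = true and col_sorted = true.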
template COOMatrix COOSort<kDLCPU, int32_t>(COOMatrix, bool);
template COOMatrix COOSort<kDLCPU, int64_t>(COOMatrix, bool);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/rowwise_pick.h
* \brief Template implementation for rowwise pick operators.
*/
#ifndef DGL_ARRAY_CPU_ROWWISE_PICK_H_
#define DGL_ARRAY_CPU_ROWWISE_PICK_H_
#include <dgl/array.h>
#include <algorithm>
#include <functional>
namespace dgl {
namespace aten {
namespace impl {
// User-defined function for picking elements from one row.
//
// The column indices of the given row are stored in
// [col + off, col + off + len)
//
// Similarly, the data indices are stored in
// [data + off, data + off + len)
// Data index pointer could be NULL, which means data[i] == i
//
// *ATTENTION*: This function will be invoked concurrently. Please make sure
// it is thread-safe.
//
// \param rowid The row to pick from.
// \param off Starting offset of this row.
// \param len NNZ of the row.
// \param col Pointer of the column indices.
// \param data Pointer of the data indices.
// \param out_idx Picked indices in [off, off + len).
template <typename IdxType>
using PickFn = std::function<void(
IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data,
IdxType* out_idx)>;
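// A minimal illustrative pick function (hypothetical helper, not part of this
// file): deterministically take the first num_picks entries of the row. It
// writes absolute positions in [off, off + len) to out_idx, as required above.
//
//   template <typename IdxType>
//   PickFn<IdxType> FirstKPickFn(int64_t num_picks) {
//     return [num_picks](IdxType rowid, IdxType off, IdxType len,
//                        const IdxType* col, const IdxType* data,
//                        IdxType* out_idx) {
//       // Intended for replace = false, where CSRRowWisePick below only
//       // invokes pick_fn when len > num_picks, so these positions are
//       // always in range.
//       for (int64_t j = 0; j < num_picks; ++j)
//         out_idx[j] = off + j;
//     };
//   }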
// Template for picking non-zero values row-wise. The implementation utilizes
// OpenMP parallelization on rows because each row performs computation independently.
template <typename IdxType>
COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn) {
using namespace aten;
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
const IdxType* data = CSRHasData(mat)? static_cast<IdxType*>(mat.data->data) : nullptr;
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
const int64_t num_rows = rows->shape[0];
const auto& ctx = mat.indptr->ctx;
// To leverage OMP parallelization, we create three arrays to store the
// picked row, column and data indices. Each array is of length
// num_rows * num_picks. For rows whose nnz < num_picks (without replacement),
// the unused slots are padded with -1.
//
// When replace is false, we first check whether every given row has at least
// num_picks non-zero entries.
//
// If some rows fall short, the -1 paddings are removed with remove_if, which
// moves the valid elements to the head of the arrays and creates a view of the
// original arrays; this consumes a little more memory than strictly required.
//
// Otherwise, the row, col and idx arrays are used directly to construct the
// result COO matrix.
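// Worked example (illustrative values): rows = {r0, r1}, num_picks = 2,
// replace = false, where r0 has 3 non-zeros and r1 has 1. Row r1 takes its
// single entry and leaves one -1 slot, so picked_row = [r0, r0, r1, -1] before
// compaction and has length 3 afterwards.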
IdArray picked_row = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdArray picked_col = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdArray picked_idx = Full(-1, num_rows * num_picks, sizeof(IdxType) * 8, ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
IdxType* picked_idata = static_cast<IdxType*>(picked_idx->data);
bool all_has_fanout = true;
if (replace) {
all_has_fanout = true;
} else {
#pragma omp parallel for reduction(&&:all_has_fanout)
for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i];
const IdxType len = indptr[rid + 1] - indptr[rid];
all_has_fanout = all_has_fanout && (len >= num_picks);
}
}
#pragma omp parallel for
for (int64_t i = 0; i < num_rows; ++i) {
const IdxType rid = rows_data[i];
CHECK_LT(rid, mat.num_rows);
const IdxType off = indptr[rid];
const IdxType len = indptr[rid + 1] - off;
if (len <= num_picks && !replace) {
// nnz <= num_picks and w/o replacement, take all nnz
for (int64_t j = 0; j < len; ++j) {
picked_rdata[i * num_picks + j] = rid;
picked_cdata[i * num_picks + j] = indices[off + j];
picked_idata[i * num_picks + j] = data? data[off + j] : off + j;
}
} else {
pick_fn(rid, off, len,
indices, data,
picked_idata + i * num_picks);
for (int64_t j = 0; j < num_picks; ++j) {
const IdxType picked = picked_idata[i * num_picks + j];
picked_rdata[i * num_picks + j] = rid;
picked_cdata[i * num_picks + j] = indices[picked];
picked_idata[i * num_picks + j] = data? data[picked] : picked;
}
}
}
if (!all_has_fanout) {
// correct the array by remove_if
IdxType* new_row_end = std::remove_if(picked_rdata, picked_rdata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
IdxType* new_col_end = std::remove_if(picked_cdata, picked_cdata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
IdxType* new_idx_end = std::remove_if(picked_idata, picked_idata + num_rows * num_picks,
[] (IdxType i) { return i == -1; });
const int64_t new_len = (new_row_end - picked_rdata);
CHECK_EQ(new_col_end - picked_cdata, new_len);
CHECK_EQ(new_idx_end - picked_idata, new_len);
picked_row = picked_row.CreateView({new_len}, picked_row->dtype);
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
}
return COOMatrix(mat.num_rows, mat.num_cols,
picked_row, picked_col, picked_idx);
}
// Template for picking non-zero values row-wise. The implementation first slices
// out the requested rows, converts the slice to CSR format, performs the row-wise
// pick on that CSR matrix, and finally maps the picked row indices back to the
// original row ids.
template <typename IdxType>
COOMatrix COORowWisePick(COOMatrix mat, IdArray rows,
int64_t num_picks, bool replace, PickFn<IdxType> pick_fn) {
using namespace aten;
const auto& csr = COOToCSR(COOSliceRows(mat, rows));
const IdArray new_rows = Range(0, rows->shape[0], rows->dtype.bits, rows->ctx);
const auto& picked = CSRRowWisePick<IdxType>(csr, new_rows, num_picks, replace, pick_fn);
return COOMatrix(mat.num_rows, mat.num_cols,
IndexSelect(rows, picked.row), // map the row index to the correct one
picked.col,
picked.data);
}
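// Note on the remapping above (illustrative values): COOSliceRows relabels the
// requested rows to 0..|rows|-1, so picked.row holds positions into `rows`.
// With rows = [5, 2] and picked.row = [0, 0, 1], IndexSelect(rows, picked.row)
// yields [5, 5, 2], i.e. the original row ids.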
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_CPU_ROWWISE_PICK_H_
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/rowwise_sampling.cc
* \brief rowwise sampling
*/
#include <dgl/random.h>
#include <numeric>
#include "./rowwise_pick.h"
namespace dgl {
namespace aten {
namespace impl {
namespace {
// Equivalent to numpy expression: array[idx[off:off + len]]
template <typename IdxType, typename FloatType>
inline FloatArray DoubleSlice(FloatArray array, const IdxType* idx_data,
IdxType off, IdxType len) {
const FloatType* array_data = static_cast<FloatType*>(array->data);
FloatArray ret = FloatArray::Empty({len}, array->dtype, array->ctx);
FloatType* ret_data = static_cast<FloatType*>(ret->data);
for (int64_t j = 0; j < len; ++j) {
if (idx_data)
ret_data[j] = array_data[idx_data[off + j]];
else
ret_data[j] = array_data[off + j];
}
return ret;
}
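// For example (hypothetical values): array = [0.1, 0.2, 0.3, 0.4],
// idx = [3, 0, 2, 1], off = 1, len = 2 returns [array[0], array[2]] = [0.1, 0.3];
// with idx_data == nullptr it would return [array[1], array[2]] = [0.2, 0.3].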
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetSamplingPickFn(
int64_t num_samples, FloatArray prob, bool replace) {
PickFn<IdxType> pick_fn = [prob, num_samples, replace]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
// TODO(minjie): If efficiency is a problem, consider avoiding the creation of
// explicit NDArrays by directly manipulating buffers.
FloatArray prob_selected = DoubleSlice<IdxType, FloatType>(prob, data, off, len);
IdArray sampled = RandomEngine::ThreadLocal()->Choice<IdxType, FloatType>(
num_samples, prob_selected, replace);
const IdxType* sampled_data = static_cast<IdxType*>(sampled->data);
for (int64_t j = 0; j < num_samples; ++j) {
out_idx[j] = off + sampled_data[j];
}
};
return pick_fn;
}
template <typename IdxType>
inline PickFn<IdxType> GetSamplingUniformPickFn(
int64_t num_samples, bool replace) {
PickFn<IdxType> pick_fn = [num_samples, replace]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
// TODO(minjie): If efficiency is a problem, consider avoiding the creation of
// explicit NDArrays by directly manipulating buffers.
IdArray sampled = RandomEngine::ThreadLocal()->UniformChoice<IdxType>(
num_samples, len, replace);
const IdxType* sampled_data = static_cast<IdxType*>(sampled->data);
for (int64_t j = 0; j < num_samples; ++j) {
out_idx[j] = off + sampled_data[j];
}
};
return pick_fn;
}
} // namespace
/////////////////////////////// CSR ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
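// Usage sketch (illustrative only; `csr` is assumed to exist and the float
// overload of NDArray::FromVector is an assumption): sample 2 out-edges of
// rows {0, 3} without replacement, weighted by a per-edge probability array
// whose length equals the number of non-zeros of `csr`.
//
//   IdArray rows = NDArray::FromVector(std::vector<int64_t>{0, 3});
//   FloatArray prob = NDArray::FromVector(std::vector<float>{0.1f, 0.5f, 0.2f, 0.2f});
//   COOMatrix sub = CSRRowWiseSampling<kDLCPU, int64_t, float>(csr, rows, 2, prob, false);
//   // sub.row/sub.col hold the sampled endpoints; sub.data holds the picked edge ids.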
/////////////////////////////// COO ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
auto pick_fn = GetSamplingPickFn<IdxType, FloatType>(num_samples, prob, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>(
COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
COOMatrix, IdArray, int64_t, bool);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/rowwise_topk.cc
* \brief rowwise topk
*/
#include <numeric>
#include <algorithm>
#include "./rowwise_pick.h"
namespace dgl {
namespace aten {
namespace impl {
namespace {
template <typename IdxType, typename FloatType>
inline PickFn<IdxType> GetTopkPickFn(int64_t k, FloatArray weight, bool ascending) {
const FloatType* wdata = static_cast<FloatType*>(weight->data);
PickFn<IdxType> pick_fn = [k, ascending, wdata]
(IdxType rowid, IdxType off, IdxType len,
const IdxType* col, const IdxType* data,
IdxType* out_idx) {
std::function<bool(IdxType, IdxType)> compare_fn;
if (ascending) {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[data[i]] < wdata[data[j]];
};
} else {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[i] < wdata[j];
};
}
} else {
if (data) {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[data[i]] > wdata[data[j]];
};
} else {
compare_fn = [wdata, data] (IdxType i, IdxType j) {
return wdata[i] > wdata[j];
};
}
}
std::vector<IdxType> idx(len);
std::iota(idx.begin(), idx.end(), off);
std::sort(idx.begin(), idx.end(), compare_fn);
for (int64_t j = 0; j < k; ++j) {
out_idx[j] = idx[j];
}
};
return pick_fn;
}
} // namespace
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, FloatType>(k, weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
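// Usage sketch (illustrative; `csr`, `rows` and `weight` are assumed to exist):
// keep, for each requested row, the two incident edges with the largest weight.
// ascending = false sorts in descending order, so the first k picks are the largest.
//
//   COOMatrix top2 = CSRRowWiseTopk<kDLCPU, int64_t, float>(csr, rows, 2, weight, false);
//
// Design note: the pick function sorts the whole row; std::partial_sort over the
// first k elements would suffice when k is much smaller than the row length.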
template <DLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, FloatType>(k, weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl
} // namespace aten
} // namespace dgl
...@@ -121,17 +121,17 @@ template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray); ...@@ -121,17 +121,17 @@ template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray);
///////////////////////////// COOGetRowDataAndIndices ///////////////////////////// ///////////////////////////// COOGetRowDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices( std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row) { COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row; CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* coo_data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* coo_data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
std::vector<IdType> indices; std::vector<IdType> indices;
std::vector<DType> data; std::vector<IdType> data;
for (int64_t i = 0; i < coo.row->shape[0]; ++i) { for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
if (coo_row_data[i] == row) { if (coo_row_data[i] == row) {
...@@ -144,20 +144,20 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices( ...@@ -144,20 +144,20 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
} }
template std::pair<NDArray, NDArray> template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int32_t, int32_t>(COOMatrix, int64_t); COOGetRowDataAndIndices<kDLCPU, int32_t>(COOMatrix, int64_t);
template std::pair<NDArray, NDArray> template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int64_t, int64_t>(COOMatrix, int64_t); COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t);
///////////////////////////// COOGetData ///////////////////////////// ///////////////////////////// COOGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) { NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row; CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col; CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
std::vector<DType> ret_vec; std::vector<IdType> ret_vec;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
for (IdType i = 0; i < coo.row->shape[0]; ++i) { for (IdType i = 0; i < coo.row->shape[0]; ++i) {
if (coo_row_data[i] == row && coo_col_data[i] == col) if (coo_row_data[i] == row && coo_col_data[i] == col)
ret_vec.push_back(data ? data[i] : i); ret_vec.push_back(data ? data[i] : i);
...@@ -165,12 +165,12 @@ NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) { ...@@ -165,12 +165,12 @@ NDArray COOGetData(COOMatrix coo, int64_t row, int64_t col) {
return NDArray::FromVector(ret_vec); return NDArray::FromVector(ret_vec);
} }
template NDArray COOGetData<kDLCPU, int32_t, int32_t>(COOMatrix, int64_t, int64_t); template NDArray COOGetData<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template NDArray COOGetData<kDLCPU, int64_t, int64_t>(COOMatrix, int64_t, int64_t); template NDArray COOGetData<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
///////////////////////////// COOGetDataAndIndices ///////////////////////////// ///////////////////////////// COOGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices( std::vector<NDArray> COOGetDataAndIndices(
COOMatrix coo, NDArray rows, NDArray cols) { COOMatrix coo, NDArray rows, NDArray cols) {
const int64_t rowlen = rows->shape[0]; const int64_t rowlen = rows->shape[0];
...@@ -186,10 +186,10 @@ std::vector<NDArray> COOGetDataAndIndices( ...@@ -186,10 +186,10 @@ std::vector<NDArray> COOGetDataAndIndices(
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
std::vector<IdType> ret_rows, ret_cols; std::vector<IdType> ret_rows, ret_cols;
std::vector<DType> ret_data; std::vector<IdType> ret_data;
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) { for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
const IdType row_id = row_data[i], col_id = col_data[j]; const IdType row_id = row_data[i], col_id = col_data[j];
...@@ -209,42 +209,38 @@ std::vector<NDArray> COOGetDataAndIndices( ...@@ -209,42 +209,38 @@ std::vector<NDArray> COOGetDataAndIndices(
NDArray::FromVector(ret_data)}; NDArray::FromVector(ret_data)};
} }
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t, int32_t>( template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t>(
COOMatrix coo, NDArray rows, NDArray cols); COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t, int64_t>( template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t>(
COOMatrix coo, NDArray rows, NDArray cols); COOMatrix coo, NDArray rows, NDArray cols);
///////////////////////////// COOTranspose ///////////////////////////// ///////////////////////////// COOTranspose /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) { COOMatrix COOTranspose(COOMatrix coo) {
return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data}; return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
} }
template COOMatrix COOTranspose<kDLCPU, int32_t, int32_t>(COOMatrix coo); template COOMatrix COOTranspose<kDLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDLCPU, int64_t, int64_t>(COOMatrix coo); template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR ///////////////////////////// ///////////////////////////// COOToCSR /////////////////////////////
// complexity: time O(NNZ), space O(1) // complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) { CSRMatrix COOToCSR(COOMatrix coo) {
const int64_t N = coo.num_rows; const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0]; const int64_t NNZ = coo.row->shape[0];
const IdType* row_data = static_cast<IdType*>(coo.row->data); const IdType* row_data = static_cast<IdType*>(coo.row->data);
const IdType* col_data = static_cast<IdType*>(coo.col->data); const IdType* col_data = static_cast<IdType*>(coo.col->data);
const IdType* data = COOHasData(coo)? static_cast<IdType*>(coo.data->data) : nullptr;
NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx); NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx); NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
NDArray ret_data; NDArray ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
if (COOHasData(coo)) {
ret_data = NDArray::Empty({NNZ}, coo.data->dtype, coo.data->ctx);
} else {
// if no data array in the input coo, the return data array is a shuffle index.
ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
}
IdType* Bp = static_cast<IdType*>(ret_indptr->data); IdType* Bp = static_cast<IdType*>(ret_indptr->data);
IdType* Bi = static_cast<IdType*>(ret_indices->data); IdType* Bi = static_cast<IdType*>(ret_indices->data);
IdType* Bx = static_cast<IdType*>(ret_data->data);
std::fill(Bp, Bp + N, 0); std::fill(Bp, Bp + N, 0);
...@@ -263,14 +259,7 @@ CSRMatrix COOToCSR(COOMatrix coo) { ...@@ -263,14 +259,7 @@ CSRMatrix COOToCSR(COOMatrix coo) {
for (int64_t i = 0; i < NNZ; ++i) { for (int64_t i = 0; i < NNZ; ++i) {
const IdType r = row_data[i]; const IdType r = row_data[i];
Bi[Bp[r]] = col_data[i]; Bi[Bp[r]] = col_data[i];
if (COOHasData(coo)) { Bx[Bp[r]] = data? data[i] : i;
const DType* data = static_cast<DType*>(coo.data->data);
DType* Bx = static_cast<DType*>(ret_data->data);
Bx[Bp[r]] = data[i];
} else {
IdType* Bx = static_cast<IdType*>(ret_data->data);
Bx[Bp[r]] = i;
}
Bp[r]++; Bp[r]++;
} }
...@@ -281,25 +270,28 @@ CSRMatrix COOToCSR(COOMatrix coo) { ...@@ -281,25 +270,28 @@ CSRMatrix COOToCSR(COOMatrix coo) {
last = temp; last = temp;
} }
return CSRMatrix{coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data}; return CSRMatrix(coo.num_rows, coo.num_cols,
ret_indptr, ret_indices, ret_data,
coo.col_sorted);
} }
template CSRMatrix COOToCSR<kDLCPU, int32_t, int32_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLCPU, int64_t, int64_t>(COOMatrix coo); template CSRMatrix COOToCSR<kDLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOSliceRows ///////////////////////////// ///////////////////////////// COOSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) { COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
// TODO(minjie): use binary search when coo.row_sorted is true
CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start; CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
CHECK(end > 0 && end <= coo.num_rows) << "Invalid end row " << end; CHECK(end > 0 && end <= coo.num_rows) << "Invalid end row " << end;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* coo_data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* coo_data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
std::vector<IdType> ret_row, ret_col; std::vector<IdType> ret_row, ret_col;
std::vector<DType> ret_data; std::vector<IdType> ret_data;
for (int64_t i = 0; i < coo.row->shape[0]; ++i) { for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
const IdType row_id = coo_row_data[i]; const IdType row_id = coo_row_data[i];
...@@ -310,25 +302,27 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) { ...@@ -310,25 +302,27 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
ret_data.push_back(coo_data ? coo_data[i] : i); ret_data.push_back(coo_data ? coo_data[i] : i);
} }
} }
return COOMatrix{ return COOMatrix(
end - start, end - start,
coo.num_cols, coo.num_cols,
NDArray::FromVector(ret_row), NDArray::FromVector(ret_row),
NDArray::FromVector(ret_col), NDArray::FromVector(ret_col),
NDArray::FromVector(ret_data)}; NDArray::FromVector(ret_data),
coo.row_sorted,
coo.col_sorted);
} }
template COOMatrix COOSliceRows<kDLCPU, int32_t, int32_t>(COOMatrix, int64_t, int64_t); template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDLCPU, int64_t, int64_t>(COOMatrix, int64_t, int64_t); template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) { COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* coo_data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* coo_data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
std::vector<IdType> ret_row, ret_col; std::vector<IdType> ret_row, ret_col;
std::vector<DType> ret_data; std::vector<IdType> ret_data;
IdHashMap<IdType> hashmap(rows); IdHashMap<IdType> hashmap(rows);
...@@ -348,24 +342,25 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) { ...@@ -348,24 +342,25 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
coo.num_cols, coo.num_cols,
NDArray::FromVector(ret_row), NDArray::FromVector(ret_row),
NDArray::FromVector(ret_col), NDArray::FromVector(ret_col),
NDArray::FromVector(ret_data)}; NDArray::FromVector(ret_data),
coo.row_sorted, coo.col_sorted};
} }
template COOMatrix COOSliceRows<kDLCPU, int32_t, int32_t>(COOMatrix , NDArray); template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDLCPU, int64_t, int64_t>(COOMatrix , NDArray); template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix , NDArray);
///////////////////////////// COOSliceMatrix ///////////////////////////// ///////////////////////////// COOSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) { COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data); const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data); const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
const DType* coo_data = COOHasData(coo) ? static_cast<DType*>(coo.data->data) : nullptr; const IdType* coo_data = COOHasData(coo) ? static_cast<IdType*>(coo.data->data) : nullptr;
IdHashMap<IdType> row_map(rows), col_map(cols); IdHashMap<IdType> row_map(rows), col_map(cols);
std::vector<IdType> ret_row, ret_col; std::vector<IdType> ret_row, ret_col;
std::vector<DType> ret_data; std::vector<IdType> ret_data;
for (int64_t i = 0; i < coo.row->shape[0]; ++i) { for (int64_t i = 0; i < coo.row->shape[0]; ++i) {
const IdType row_id = coo_row_data[i]; const IdType row_id = coo_row_data[i];
...@@ -381,17 +376,16 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray ...@@ -381,17 +376,16 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
} }
} }
return COOMatrix{ return COOMatrix(rows->shape[0], cols->shape[0],
rows->shape[0], NDArray::FromVector(ret_row),
cols->shape[0], NDArray::FromVector(ret_col),
NDArray::FromVector(ret_row), NDArray::FromVector(ret_data),
NDArray::FromVector(ret_col), coo.row_sorted, coo.col_sorted);
NDArray::FromVector(ret_data)};
} }
template COOMatrix COOSliceMatrix<kDLCPU, int32_t, int32_t>( template COOMatrix COOSliceMatrix<kDLCPU, int32_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDLCPU, int64_t, int64_t>( template COOMatrix COOSliceMatrix<kDLCPU, int64_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols); COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
} // namespace impl } // namespace impl
......
...@@ -178,7 +178,7 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) { ...@@ -178,7 +178,7 @@ CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) {
} // namespace } // namespace
HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs)
: BaseHeteroGraph(meta_graph), relation_graphs_(rel_graphs) { : BaseHeteroGraph(meta_graph) {
// Sanity check // Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
...@@ -218,6 +218,17 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -218,6 +218,17 @@ HeteroGraph::HeteroGraph(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
CHECK_EQ(num_verts_per_type_[dsttype], nv) CHECK_EQ(num_verts_per_type_[dsttype], nv)
<< "Mismatch number of vertices for vertex type " << dsttype; << "Mismatch number of vertices for vertex type " << dsttype;
} }
relation_graphs_.resize(rel_graphs.size());
for (size_t i = 0; i < rel_graphs.size(); ++i) {
HeteroGraphPtr relg = rel_graphs[i];
if (std::dynamic_pointer_cast<UnitGraph>(relg)) {
relation_graphs_[i] = std::dynamic_pointer_cast<UnitGraph>(relg);
} else {
relation_graphs_[i] = CHECK_NOTNULL(
std::dynamic_pointer_cast<UnitGraph>(relg->GetRelationGraph(0)));
}
}
} }
bool HeteroGraph::IsMultigraph() const { bool HeteroGraph::IsMultigraph() const {
...@@ -366,6 +377,76 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp ...@@ -366,6 +377,76 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
return FlattenedHeteroGraphPtr(result); return FlattenedHeteroGraphPtr(result);
} }
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
auto relgraph = graph->GetRelationGraph(etype);
if (vids[dst_vtype]->shape[0] == 0) {
// create a placeholder graph
subrels[etype] = UnitGraph::Empty(
relgraph->NumVertexTypes(),
graph->NumVertices(src_vtype),
graph->NumVertices(dst_vtype),
graph->DataType(), graph->Context());
induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
} else {
const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
subrels[etype] = UnitGraph::CreateFromCOO(
relgraph->NumVertexTypes(),
graph->NumVertices(src_vtype),
graph->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges[etype] = earr.id;
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
ret.induced_edges = std::move(induced_edges);
return ret;
}
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
CHECK_EQ(vids.size(), graph->NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
std::vector<HeteroGraphPtr> subrels(graph->NumEdgeTypes());
std::vector<IdArray> induced_edges(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
auto relgraph = graph->GetRelationGraph(etype);
if (vids[src_vtype]->shape[0] == 0) {
// create a placeholder graph
subrels[etype] = UnitGraph::Empty(
relgraph->NumVertexTypes(),
graph->NumVertices(src_vtype),
graph->NumVertices(dst_vtype),
graph->DataType(), graph->Context());
induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
} else {
const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
subrels[etype] = UnitGraph::CreateFromCOO(
relgraph->NumVertexTypes(),
graph->NumVertices(src_vtype),
graph->NumVertices(dst_vtype),
earr.src,
earr.dst);
induced_edges[etype] = earr.id;
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
ret.induced_edges = std::move(induced_edges);
return ret;
}
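// Usage sketch (illustrative; assumes a graph `g` with a single vertex type and
// a single edge type): take the subgraph induced by the in-edges of nodes {1, 4}.
// The node set is preserved; only the edges are filtered.
//
//   std::vector<IdArray> vids = {NDArray::FromVector(std::vector<int64_t>{1, 4})};
//   HeteroSubgraph sg = InEdgeGraph(g, vids);
//   // sg.induced_edges[0] holds the ids of the edges whose destination is 1 or 4;
//   // OutEdgeGraph(g, vids) is the symmetric variant that filters by source node.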
HeteroGraphPtr DisjointUnionHeteroGraph( HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty"; CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
...@@ -491,6 +572,23 @@ HeteroGraphPtr CreateHeteroGraph( ...@@ -491,6 +572,23 @@ HeteroGraphPtr CreateHeteroGraph(
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs)); return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
} }
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format) {
auto unit_g = UnitGraph::CreateFromCOO(
num_vtypes, num_src, num_dst, row, col, restrict_format);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format) {
auto unit_g = UnitGraph::CreateFromCSR(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
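// Usage sketch (illustrative; the "any" restrict-format string is an assumption):
// wrap a 4-node directed cycle, given in CSR form, into a one-relation heterograph.
//
//   IdArray indptr  = NDArray::FromVector(std::vector<int64_t>{0, 1, 2, 3, 4});
//   IdArray indices = NDArray::FromVector(std::vector<int64_t>{1, 2, 3, 0});
//   IdArray eids    = NDArray::FromVector(std::vector<int64_t>{0, 1, 2, 3});
//   HeteroGraphPtr g = CreateFromCSR(1, 4, 4, indptr, indices, eids,
//                                    ParseSparseFormat("any"));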
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>>
CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) { CompactGraphs(const std::vector<HeteroGraphPtr> &graphs) {
std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result; std::pair<std::vector<HeteroGraphPtr>, std::vector<IdArray>> result;
...@@ -543,8 +641,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") ...@@ -543,8 +641,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
IdArray row = args[3]; IdArray row = args[3];
IdArray col = args[4]; IdArray col = args[4];
SparseFormat restrict_format = ParseSparseFormat(args[5]); SparseFormat restrict_format = ParseSparseFormat(args[5]);
auto hgptr = UnitGraph::CreateFromCOO( auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, restrict_format);
nvtypes, num_src, num_dst, row, col, restrict_format);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
...@@ -557,8 +654,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR") ...@@ -557,8 +654,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
IdArray indices = args[4]; IdArray indices = args[4];
IdArray edge_ids = args[5]; IdArray edge_ids = args[5];
SparseFormat restrict_format = ParseSparseFormat(args[6]); SparseFormat restrict_format = ParseSparseFormat(args[6]);
auto hgptr = UnitGraph::CreateFromCSR( auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids,
nvtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format); restrict_format);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
...@@ -924,6 +1021,24 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCompactGraphs") ...@@ -924,6 +1021,24 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCompactGraphs")
*rv = result; *rv = result;
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = InEdgeGraph(hg.sptr(), nodes);
*rv = HeteroGraphRef(ret);
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
*ret = OutEdgeGraph(hg.sptr(), nodes);
*rv = HeteroGraphRef(ret);
});
// HeteroSubgraph C APIs // HeteroSubgraph C APIs
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
#include "./unit_graph.h"
namespace dgl { namespace dgl {
...@@ -163,6 +164,22 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -163,6 +164,22 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->GetAdj(0, transpose, fmt); return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);
} }
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCOOMatrix(0);
}
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCSCMatrix(0);
}
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCSRMatrix(0);
}
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_format);
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override; HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
...@@ -176,7 +193,6 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -176,7 +193,6 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \return Save HeteroGraph to stream, using CSRMatrix */ /*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
...@@ -185,7 +201,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -185,7 +201,7 @@ class HeteroGraph : public BaseHeteroGraph {
HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {} HeteroGraph() : BaseHeteroGraph(static_cast<GraphPtr>(nullptr)) {}
/*! \brief A map from edge type to unit graph */ /*! \brief A map from edge type to unit graph */
std::vector<HeteroGraphPtr> relation_graphs_; std::vector<UnitGraphPtr> relation_graphs_;
/*! \brief A map from vert type to the number of verts in the type */ /*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_; std::vector<int64_t> num_verts_per_type_;
......