"vscode:/vscode.git/clone" did not exist on "fe8a163986672bbbec1a922231be229cc79dafe6"
Unverified commit b0d9e7aa authored by Minjie Wang, committed by GitHub

[Refactor] Separating graph and sparse matrix operations (#699)

* WIP: array refactoring

* WIP: implementation

* wip

* most csr part

* WIP: on coo

* WIP: coo

* finish refactoring immutable graph

* compiled

* fix undefined ndarray copy bug; add COOToCSR when coo has no data array

* fix bug in COOToCSR

* fix bug in CSR constructor

* fix bug in in_edges(vid)

* fix OutEdges bug

* pass test_graph

* pass test_graph

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix stupid bug

* pass gpu test

* remove debug printout

* fix lint

* rm bipartite graph

* fix lint

* address comments

* fix bug in Clone

* cpp utests
parent f79188da
@@ -87,6 +87,8 @@ add_definitions(-DENABLE_PARTIAL_FRONTIER=0) # disable minigun partial frontier
 # Source file lists
 file(GLOB DGL_SRC
   src/*.cc
+  src/array/*.cc
+  src/array/cpu/*.cc
   src/kernel/*.cc
   src/kernel/cpu/*.cc
   src/runtime/*.cc
......
@@ -231,6 +231,8 @@ macro(dgl_config_cuda out_variable)
   add_definitions(-DDGL_USE_CUDA)
   file(GLOB_RECURSE DGL_CUDA_SRC
+    src/array/cuda/*.cc
+    src/array/cuda/*.cu
     src/kernel/cuda/*.cc
     src/kernel/cuda/*.cu
     src/runtime/cuda/*.cc
......
@@ -10,7 +10,9 @@
 #define DGL_ARRAY_H_

 #include <dgl/runtime/ndarray.h>
+#include <algorithm>
 #include <vector>
+#include <utility>

 namespace dgl {
@@ -23,22 +25,59 @@ typedef dgl::runtime::NDArray IntArray;
 typedef dgl::runtime::NDArray FloatArray;
 typedef dgl::runtime::NDArray TypeArray;

-/*! \brief Create a new id array with given length (on CPU) */
-IdArray NewIdArray(int64_t length);
+namespace aten {
+
+//////////////////////////////////////////////////////////////////////
+// ID array
+//////////////////////////////////////////////////////////////////////
+
+/*!
+ * \brief Create a new id array with given length
+ * \param length The array length
+ * \param ctx The array context
+ * \param nbits The number of integer bits
+ * \return id array
+ */
+IdArray NewIdArray(int64_t length,
+                   DLContext ctx = DLContext{kDLCPU, 0},
+                   uint8_t nbits = 64);

-/*!
- * \brief Create a new boolean array with given length (on CPU)
- * \note the elements are 64-bit.
- */
-BoolArray NewBoolArray(int64_t length);
+/*!
+ * \brief Create a new id array using the given vector data
+ * \param vec The vector data
+ * \param nbits The integer bits of the returned array
+ * \param ctx The array context
+ * \return the id array
+ */
+template <typename T>
+IdArray VecToIdArray(const std::vector<T>& vec,
+                     uint8_t nbits = 64,
+                     DLContext ctx = DLContext{kDLCPU, 0});

-/*! \brief Create a new id array with the given vector data (on CPU) */
-IdArray VecToIdArray(const std::vector<dgl_id_t>& vec);
+/*!
+ * \brief Return an array representing a 1D range.
+ * \param low Lower bound (inclusive).
+ * \param high Higher bound (exclusive).
+ * \param nbits result array's bits (32 or 64)
+ * \param ctx Device context
+ * \return range array
+ */
+IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);

-/*! \brief Create a copy of the given array */
+/*!
+ * \brief Return an array full of the given value
+ * \param val The value to fill.
+ * \param length Number of elements.
+ * \param nbits result array's bits (32 or 64)
+ * \param ctx Device context
+ * \return the result array
+ */
+IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
+
+/*! \brief Create a deep copy of the given array */
 IdArray Clone(IdArray arr);

-/*! \brief Convert the idarray to the given bit width (on CPU) */
+/*! \brief Convert the idarray to the given bit width */
 IdArray AsNumBits(IdArray arr, uint8_t bits);

 /*! \brief Arithmetic functions */
@@ -57,30 +96,186 @@ IdArray Sub(dgl_id_t lhs, IdArray rhs);
 IdArray Mul(dgl_id_t lhs, IdArray rhs);
 IdArray Div(dgl_id_t lhs, IdArray rhs);
+BoolArray LT(IdArray lhs, dgl_id_t rhs);

 /*! \brief Stack two arrays (of len L) into a 2*L length array */
 IdArray HStack(IdArray arr1, IdArray arr2);

+/*! \brief Return the data under the index. In numpy notation, A[I] */
+int64_t IndexSelect(IdArray array, int64_t index);
+IdArray IndexSelect(IdArray array, IdArray index);
+
+/*!
+ * \brief Relabel the given ids to consecutive ids.
+ *
+ * Relabeling is done inplace. The mapping is created from the union
+ * of the give arrays.
+ *
+ * \param arrays The id arrays to relabel.
+ * \return mapping array M from new id to old id.
+ */
+IdArray Relabel_(const std::vector<IdArray>& arrays);
+
+//////////////////////////////////////////////////////////////////////
+// Sparse matrix
+//////////////////////////////////////////////////////////////////////
+
-/*! \brief Plain CSR matrix */
+/*!
+ * \brief Plain CSR matrix
+ *
+ * The column indices are 0-based and are not necessarily sorted.
+ *
+ * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
+ * that have the same row, col indices. It corresponds to multigraph in
+ * graph terminology.
+ */
 struct CSRMatrix {
-  IdArray indptr, indices, data;
+  /*! \brief the dense shape of the matrix */
+  int64_t num_rows, num_cols;
+  /*! \brief CSR index arrays */
+  runtime::NDArray indptr, indices;
+  /*! \brief data array, could be empty. */
+  runtime::NDArray data;
 };

-/*! \brief Plain COO structure */
+/*!
+ * \brief Plain COO structure
+ *
+ * Note that we do allow duplicate non-zero entries -- multiple non-zero entries
+ * that have the same row, col indices. It corresponds to multigraph in
+ * graph terminology.
+ *
+ * We call a COO matrix is *coalesced* if its row index is sorted.
+ */
 struct COOMatrix {
-  IdArray row, col, data;
+  /*! \brief the dense shape of the matrix */
+  int64_t num_rows, num_cols;
+  /*! \brief COO index arrays */
+  runtime::NDArray row, col;
+  /*! \brief data array, could be empty. */
+  runtime::NDArray data;
 };

-/*! \brief Slice rows of the given matrix and return. */
-CSRMatrix SliceRows(const CSRMatrix& csr, int64_t start, int64_t end);
+///////////////////////// CSR routines //////////////////////////
+
+/*! \brief Return true if the value (row, col) is non-zero */
+bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col);
+/*!
+ * \brief Batched implementation of CSRIsNonZero.
+ * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ */
+runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col);
+
+/*! \brief Return the nnz of the given row */
+int64_t CSRGetRowNNZ(CSRMatrix , int64_t row);
+runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row);
+
+/*! \brief Return the column index array of the given row */
+runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row);
+
+/*! \brief Return the data array of the given row */
+runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row);

-/*! \brief Convert COO matrix to CSR matrix. */
-CSRMatrix ToCSR(const COOMatrix);
+/* \brief Get data. The return type is an ndarray due to possible duplicate entries. */
+runtime::NDArray CSRGetData(CSRMatrix , int64_t row, int64_t col);
+/*!
+ * \brief Batched implementation of CSRGetData.
+ * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ */
+runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
+
+/*!
+ * \brief Get the data and the row,col indices for each returned entries.
+ * \note This operator allows broadcasting (i.e, either row or col can be of length 1).
+ */
+std::vector<runtime::NDArray> CSRGetDataAndIndices(
+    CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
+
+/*! \brief Return a transposed CSR matrix */
+CSRMatrix CSRTranspose(CSRMatrix csr);
+
+/*!
+ * \brief Convert CSR matrix to COO matrix.
+ * \param csr Input csr matrix
+ * \param data_as_order If true, the data array in the input csr matrix contains the order
+ *                      by which the resulting COO tuples are stored. In this case, the
+ *                      data array of the resulting COO matrix will be empty because it
+ *                      is essentially a consecutive range.
+ * \return a coo matrix
+ */
+COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
+
+/*!
+ * \brief Slice rows of the given matrix and return.
+ * \param csr CSR matrix
+ * \param start Start row id (inclusive)
+ * \param end End row id (exclusive)
+ *
+ * Examples:
+ * num_rows = 4
+ * num_cols = 4
+ * indptr = [0, 2, 3, 3, 5]
+ * indices = [1, 0, 2, 3, 1]
+ *
+ * After CSRSliceRows(csr, 1, 3)
+ *
+ * num_rows = 2
+ * num_cols = 4
+ * indptr = [0, 1, 1]
+ * indices = [2]
+ */
+CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
+CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
+
+/*!
+ * \brief Get the submatrix specified by the row and col ids.
+ *
+ * In numpy notation, given matrix M, row index array I, col index array J
+ * This function returns the submatrix M[I, J].
+ *
+ * \param csr The input csr matrix
+ * \param rows The row index to select
+ * \param cols The col index to select
+ * \return submatrix
+ */
+CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
+
+/*! \return True if the matrix has duplicate entries */
+bool CSRHasDuplicate(CSRMatrix csr);
+
+///////////////////////// COO routines //////////////////////////
+
+/*! \return True if the matrix has duplicate entries */
+bool COOHasDuplicate(COOMatrix coo);
+
+/*!
+ * \brief Convert COO matrix to CSR matrix.
+ *
+ * If the input COO matrix does not have data array, the data array of
+ * the result CSR matrix stores a shuffle index for how the entries
+ * will be reordered in CSR. The i^th entry in the result CSR corresponds
+ * to the CSR.data[i] th entry in the input COO.
+ */
+CSRMatrix COOToCSR(COOMatrix coo);

-/*! \brief Convert COO matrix to CSR matrix. */
-COOMatrix ToCOO(const CSRMatrix);
+// inline implementations
+template <typename T>
+IdArray VecToIdArray(const std::vector<T>& vec,
+                     uint8_t nbits,
+                     DLContext ctx) {
+  IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
+  if (nbits == 32) {
+    std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
+  } else if (nbits == 64) {
+    std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(ret->data));
+  } else {
+    LOG(FATAL) << "Only int32 or int64 is supported.";
+  }
+  return ret.CopyTo(ctx);
+}
+
+}  // namespace aten

 }  // namespace dgl

 #endif  // DGL_ARRAY_H_
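
Editor's note: for orientation, here is a minimal sketch of how the refactored aten:: API declared above might be used. It is illustrative only and not part of this change; it assumes a CPU build, the default 64-bit ids, and that an empty NDArray is a valid "no data" value for COOMatrix (the commit message's "add COOToCSR when coo has no data array" suggests this is supported).

// Illustrative use of the aten:: array API (editor's sketch, not part of the diff).
#include <dgl/array.h>
#include <vector>

void AtenApiSketch() {
  using namespace dgl;
  // 64-bit id array of length 10 on CPU (ctx and nbits keep their defaults).
  IdArray ids = aten::NewIdArray(10);
  // Build id arrays from std::vector data.
  IdArray row = aten::VecToIdArray(std::vector<int64_t>{0, 0, 1, 3});
  IdArray col = aten::VecToIdArray(std::vector<int64_t>{1, 2, 2, 1});
  // A 4x4 COO matrix; the data field is left empty.
  aten::COOMatrix coo{4, 4, row, col, runtime::NDArray()};
  // Convert to CSR; with no data array, the result's data stores a shuffle index.
  aten::CSRMatrix csr = aten::COOToCSR(coo);
  // Number of non-zeros in row 0 (the out-degree when the matrix is an adjacency).
  int64_t nnz0 = aten::CSRGetRowNNZ(csr, 0);
  (void)ids; (void)nnz0;
}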
@@ -128,49 +128,13 @@ class GraphInterface {
   }

   /*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
-  virtual BoolArray HasVertices(IdArray vids) const {
-    const auto len = vids->shape[0];
-    BoolArray rst = NewBoolArray(len);
-    const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data);
-    dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
-    const uint64_t nverts = NumVertices();
-    for (int64_t i = 0; i < len; ++i) {
-      rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
-    }
-    return rst;
-  }
+  virtual BoolArray HasVertices(IdArray vids) const = 0;

   /*! \return true if the given edge is in the graph.*/
   virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;

   /*! \return a 0-1 array indicating whether the given edges are in the graph.*/
-  virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
-    const auto srclen = src_ids->shape[0];
-    const auto dstlen = dst_ids->shape[0];
-    const auto rstlen = std::max(srclen, dstlen);
-    BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
-    dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
-    const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
-    const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
-    if (srclen == 1) {
-      // one-many
-      for (int64_t i = 0; i < dstlen; ++i) {
-        rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0;
-      }
-    } else if (dstlen == 1) {
-      // many-one
-      for (int64_t i = 0; i < srclen; ++i) {
-        rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0;
-      }
-    } else {
-      // many-many
-      CHECK(srclen == dstlen) << "Invalid src and dst id array.";
-      for (int64_t i = 0; i < srclen; ++i) {
-        rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0;
-      }
-    }
-    return rst;
-  }
+  virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;

   /*!
    * \brief Find the predecessors of a vertex.
......
@@ -69,11 +69,11 @@ class CSR : public GraphInterface {
   }

   DLContext Context() const override {
-    return indptr_->ctx;
+    return adj_.indptr->ctx;
   }

   uint8_t NumBits() const override {
-    return indices_->dtype.bits;
+    return adj_.indices->dtype.bits;
   }

   bool IsMultigraph() const override;
@@ -83,15 +83,22 @@ class CSR : public GraphInterface {
   }

   uint64_t NumVertices() const override {
-    return indptr_->shape[0] - 1;
+    return adj_.indptr->shape[0] - 1;
   }

   uint64_t NumEdges() const override {
-    return indices_->shape[0];
+    return adj_.indices->shape[0];
+  }
+
+  BoolArray HasVertices(IdArray vids) const override {
+    LOG(FATAL) << "Not enabled for CSR graph";
+    return {};
   }

   bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
+  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;

   IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
     LOG(FATAL) << "CSR graph does not support efficient predecessor query."
               << " Please use successors on the reverse CSR graph.";
@@ -147,8 +154,7 @@ class CSR : public GraphInterface {
   }

   uint64_t OutDegree(dgl_id_t vid) const override {
-    const int64_t* indptr_data = static_cast<int64_t*>(indptr_->data);
-    return indptr_data[vid + 1] - indptr_data[vid];
+    return aten::CSRGetRowNNZ(adj_, vid);
   }

   DegreeArray OutDegrees(IdArray vids) const override;
@@ -165,21 +171,9 @@ class CSR : public GraphInterface {
     return Transpose();
   }

-  DGLIdIters SuccVec(dgl_id_t vid) const override {
-    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
-    const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
-    const dgl_id_t start = indptr_data[vid];
-    const dgl_id_t end = indptr_data[vid + 1];
-    return DGLIdIters(indices_data + start, indices_data + end);
-  }
+  DGLIdIters SuccVec(dgl_id_t vid) const override;

-  DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
-    const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
-    const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
-    const dgl_id_t start = indptr_data[vid];
-    const dgl_id_t end = indptr_data[vid + 1];
-    return DGLIdIters(eid_data + start, eid_data + end);
-  }
+  DGLIdIters OutEdgeVec(dgl_id_t vid) const override;

   DGLIdIters PredVec(dgl_id_t vid) const override {
     LOG(FATAL) << "CSR graph does not support efficient PredVec."
@@ -201,7 +195,7 @@ class CSR : public GraphInterface {
   std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
     CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
-    return {indptr_, indices_, edge_ids_};
+    return {adj_.indptr, adj_.indices, adj_.data};
   }

   /*! \brief Indicate whether this uses shared memory. */
@@ -220,8 +214,8 @@ class CSR : public GraphInterface {
    * \note The csr matrix shares the storage with this graph.
    * The data field of the CSR matrix stores the edge ids.
    */
-  CSRMatrix ToCSRMatrix() const {
-    return CSRMatrix{indptr_, indices_, edge_ids_};
+  aten::CSRMatrix ToCSRMatrix() const {
+    return adj_;
   }

   /*!
@@ -247,26 +241,19 @@ class CSR : public GraphInterface {
   // member getters
-  IdArray indptr() const { return indptr_; }
-  IdArray indices() const { return indices_; }
-  IdArray edge_ids() const { return edge_ids_; }
+  IdArray indptr() const { return adj_.indptr; }
+  IdArray indices() const { return adj_.indices; }
+  IdArray edge_ids() const { return adj_.data; }

  private:
   /*! \brief prive default constructor */
   CSR() {}

-  // The CSR arrays.
-  // - The index is 0-based.
-  // - The out edges of vertex v is stored from `indices_[indptr_[v]]` to
-  //   `indices_[indptr_[v+1]]` (exclusive).
-  // - The indices are *not* necessarily sorted.
-  // TODO(minjie): in the future, we should separate CSR and graph. A general CSR
-  // is not necessarily a graph, but graph operations could be implemented by
-  // CSR matrix operations. CSR matrix operations would be backed by different
-  // devices (CPU, CUDA, ...), while graph interface will not be aware of that.
-  IdArray indptr_, indices_, edge_ids_;
+  // The internal CSR adjacency matrix.
+  // The data field stores edge ids.
+  aten::CSRMatrix adj_;

   // whether the graph is a multi-graph
   LazyObject<bool> is_multigraph_;
@@ -301,11 +288,11 @@ class COO : public GraphInterface {
   }

   DLContext Context() const override {
-    return src_->ctx;
+    return adj_.row->ctx;
   }

   uint8_t NumBits() const override {
-    return src_->dtype.bits;
+    return adj_.row->dtype.bits;
   }

   bool IsMultigraph() const override;
@@ -315,23 +302,34 @@ class COO : public GraphInterface {
   }

   uint64_t NumVertices() const override {
-    return num_vertices_;
+    return adj_.num_rows;
   }

   uint64_t NumEdges() const override {
-    return src_->shape[0];
+    return adj_.row->shape[0];
   }

   bool HasVertex(dgl_id_t vid) const override {
     return vid < NumVertices();
   }

+  BoolArray HasVertices(IdArray vids) const override {
+    LOG(FATAL) << "Not enabled for COO graph";
+    return {};
+  }
+
   bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
     LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
               << " Please use CSR graph or AdjList graph instead.";
     return false;
   }

+  BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
+    LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
+              << " Please use CSR graph or AdjList graph instead.";
+    return {};
+  }
+
   IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
     LOG(FATAL) << "COO graph does not support efficient Predecessors."
               << " Please use CSR graph or AdjList graph instead.";
@@ -356,12 +354,7 @@ class COO : public GraphInterface {
     return {};
   }

-  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
-    CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
-    const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
-    const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
-    return std::make_pair(src_data[eid], dst_data[eid]);
-  }
+  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;

   EdgeArray FindEdges(IdArray eids) const override;
@@ -460,15 +453,15 @@ class COO : public GraphInterface {
   std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
     CHECK(fmt == "coo") << "Not valid adj format request.";
     if (transpose) {
-      return {HStack(dst_, src_)};
+      return {aten::HStack(adj_.col, adj_.row)};
     } else {
-      return {HStack(src_, dst_)};
+      return {aten::HStack(adj_.row, adj_.col)};
     }
   }

   /*! \brief Return the transpose of this COO */
   COOPtr Transpose() const {
-    return COOPtr(new COO(num_vertices_, dst_, src_));
+    return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
   }

   /*! \brief Convert this COO to CSR */
@@ -479,8 +472,8 @@ class COO : public GraphInterface {
    * \note The coo matrix shares the storage with this graph.
    * The data field of the coo matrix is none.
    */
-  COOMatrix ToCOOMatrix() const {
-    return COOMatrix{src_, dst_, {}};
+  aten::COOMatrix ToCOOMatrix() const {
+    return adj_;
   }

   /*!
@@ -511,18 +504,18 @@ class COO : public GraphInterface {
   // member getters
-  IdArray src() const { return src_; }
-  IdArray dst() const { return dst_; }
+  IdArray src() const { return adj_.row; }
+  IdArray dst() const { return adj_.col; }

  private:
   /* !\brief private default constructor */
   COO() {}

-  /*! \brief number of vertices */
-  int64_t num_vertices_;
-  /*! \brief coordinate arrays */
-  IdArray src_, dst_;
+  // The internal COO adjacency matrix.
+  // The data field is empty
+  aten::COOMatrix adj_;

   /*! \brief whether the graph is a multi-graph */
   LazyObject<bool> is_multigraph_;
 };
@@ -635,6 +628,8 @@ class ImmutableGraph: public GraphInterface {
     return vid < NumVertices();
   }

+  BoolArray HasVertices(IdArray vids) const override;
+
   /*! \return true if the given edge is in the graph.*/
   bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
     if (in_csr_) {
@@ -644,6 +639,14 @@ class ImmutableGraph: public GraphInterface {
     }
   }

+  BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {
+    if (in_csr_) {
+      return in_csr_->HasEdgesBetween(dst, src);
+    } else {
+      return GetOutCSR()->HasEdgesBetween(src, dst);
+    }
+  }
+
   /*!
    * \brief Find the predecessors of a vertex.
    * \param vid The vertex id.
@@ -910,49 +913,13 @@ class ImmutableGraph: public GraphInterface {
   std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;

   /* !\brief Return in csr. If not exist, transpose the other one.*/
-  CSRPtr GetInCSR() const {
-    if (!in_csr_) {
-      if (out_csr_) {
-        const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
-        if (out_csr_->IsSharedMem())
-          LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
-                       << "It may dramatically increase memory consumption.";
-      } else {
-        CHECK(coo_) << "None of CSR, COO exist";
-        const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
-      }
-    }
-    return in_csr_;
-  }
+  CSRPtr GetInCSR() const;

   /* !\brief Return out csr. If not exist, transpose the other one.*/
-  CSRPtr GetOutCSR() const {
-    if (!out_csr_) {
-      if (in_csr_) {
-        const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
-        if (in_csr_->IsSharedMem())
-          LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
-                       << "It may dramatically increase memory consumption.";
-      } else {
-        CHECK(coo_) << "None of CSR, COO exist";
-        const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
-      }
-    }
-    return out_csr_;
-  }
+  CSRPtr GetOutCSR() const;

   /* !\brief Return coo. If not exist, create from csr.*/
-  COOPtr GetCOO() const {
-    if (!coo_) {
-      if (in_csr_) {
-        const_cast<ImmutableGraph*>(this)->coo_ = in_csr_->ToCOO()->Transpose();
-      } else {
-        CHECK(out_csr_) << "Both CSR are missing.";
-        const_cast<ImmutableGraph*>(this)->coo_ = out_csr_->ToCOO();
-      }
-    }
-    return coo_;
-  }
+  COOPtr GetCOO() const;

   /*!
    * \brief Convert the given graph to an immutable graph.
@@ -1107,12 +1074,16 @@ template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
 CSR::CSR(int64_t num_vertices, int64_t num_edges,
     IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
     bool is_multigraph): is_multigraph_(is_multigraph) {
-  indptr_ = NewIdArray(num_vertices + 1);
-  indices_ = NewIdArray(num_edges);
-  edge_ids_ = NewIdArray(num_edges);
-  dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
-  dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
-  dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(edge_ids_->data);
+  // TODO(minjie): this should be changed to a device-agnostic implementation
+  // in the future
+  adj_.num_rows = num_vertices;
+  adj_.num_cols = num_vertices;
+  adj_.indptr = aten::NewIdArray(num_vertices + 1);
+  adj_.indices = aten::NewIdArray(num_edges);
+  adj_.data = aten::NewIdArray(num_edges);
+  dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
+  dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
+  dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(adj_.data->data);
   for (int64_t i = 0; i < num_vertices + 1; ++i)
     *(indptr_data++) = *(indptr_begin++);
   for (int64_t i = 0; i < num_edges; ++i) {
......
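
Editor's note: the GetInCSR/GetOutCSR/GetCOO bodies removed above implement lazy, cached format conversion; the new declarations suggest the definitions simply move out of the header (the destination file is not shown in this excerpt). For reference, a sketch of the out-CSR path reconstructed from the removed inline body, with the shared-memory warning elided:

// Editor's sketch mirroring the inline body removed above; the actual
// out-of-line definition elsewhere in this change may differ in detail.
CSRPtr ImmutableGraph::GetOutCSR() const {
  if (!out_csr_) {
    if (in_csr_) {
      // Derive the out-CSR by transposing the cached in-CSR.
      const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
    } else {
      CHECK(coo_) << "None of CSR, COO exist";
      // Otherwise materialize it from the COO form.
      const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
    }
  }
  return out_csr_;
}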
@@ -292,16 +292,18 @@ struct NDArray::Container {
 // the usages of functions are documented in place.
 inline NDArray::NDArray(Container* data)
   : data_(data) {
+  if (data_)
     data_->IncRef();
 }

 inline NDArray::NDArray(const NDArray& other)
   : data_(other.data_) {
+  if (data_)
     data_->IncRef();
 }

 inline void NDArray::reset() {
-  if (data_ != nullptr) {
+  if (data_) {
     data_->DecRef();
     data_ = nullptr;
   }
......
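
Editor's note: this matches the "fix undefined ndarray copy bug" item in the commit message. A minimal illustration (not part of the diff, assuming NDArray is default-constructible with a null container, as the new guards imply):

// Why the data_ null checks matter: copying an empty NDArray used to call
// IncRef() on a null container. Illustrative only.
dgl::runtime::NDArray empty;          // holds no container (e.g. an absent data field)
dgl::runtime::NDArray alias = empty;  // copy ctor: IncRef() is now guarded by if (data_)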
...@@ -85,7 +85,7 @@ class ObjectBase(object): ...@@ -85,7 +85,7 @@ class ObjectBase(object):
""" """
# assign handle first to avoid error raising # assign handle first to avoid error raising
self.handle = None self.handle = None
handle = __init_by_constructor__(fconstructor, args) handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle): if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle) handle = ObjectHandle(handle)
self.handle = handle self.handle = handle
......
...@@ -151,6 +151,7 @@ class GraphIndex(object): ...@@ -151,6 +151,7 @@ class GraphIndex(object):
readonly_state : bool readonly_state : bool
New readonly state of current graph index. New readonly state of current graph index.
""" """
# TODO(minjie): very ugly code, should fix this
n_nodes, multigraph, _, src, dst = self.__getstate__() n_nodes, multigraph, _, src, dst = self.__getstate__()
self.clear_cache() self.clear_cache()
state = (n_nodes, multigraph, readonly_state, src, dst) state = (n_nodes, multigraph, readonly_state, src, dst)
......
/*!
* Copyright (c) 2019 by Contributors
* \file array.cc
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
namespace dgl {
// TODO(minjie): currently these operators are only on CPU.
IdArray NewIdArray(int64_t length) {
return IdArray::Empty({length}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
BoolArray NewBoolArray(int64_t length) {
return BoolArray::Empty({length}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
IdArray VecToIdArray(const std::vector<dgl_id_t>& vec) {
IdArray ret = NewIdArray(vec.size());
std::copy(vec.begin(), vec.end(), static_cast<dgl_id_t*>(ret->data));
return ret;
}
IdArray Clone(IdArray arr) {
IdArray ret = NewIdArray(arr->shape[0]);
ret.CopyFrom(arr);
return ret;
}
IdArray AsNumBits(IdArray arr, uint8_t bits) {
if (arr->dtype.bits == bits) {
return arr;
} else {
const int64_t len = arr->shape[0];
IdArray ret = IdArray::Empty({len},
DLDataType{kDLInt, bits, 1}, DLContext{kDLCPU, 0});
if (arr->dtype.bits == 32 && bits == 64) {
const int32_t* arr_data = static_cast<int32_t*>(arr->data);
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
} else if (arr->dtype.bits == 64 && bits == 32) {
const int64_t* arr_data = static_cast<int64_t*>(arr->data);
int32_t* ret_data = static_cast<int32_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
} else {
LOG(FATAL) << "Invalid type conversion.";
}
return ret;
}
}
IdArray Add(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] + rhs_data[i];
}
return ret;
}
IdArray Sub(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] - rhs_data[i];
}
return ret;
}
IdArray Mul(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] * rhs_data[i];
}
return ret;
}
IdArray Div(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] / rhs_data[i];
}
return ret;
}
IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] + rhs;
}
return ret;
}
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] - rhs;
}
return ret;
}
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] * rhs;
}
return ret;
}
IdArray Div(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] / rhs;
}
return ret;
}
IdArray Add(dgl_id_t lhs, IdArray rhs) {
return Add(rhs, lhs);
}
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0]);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = lhs - rhs_data[i];
}
return ret;
}
IdArray Mul(dgl_id_t lhs, IdArray rhs) {
return Mul(rhs, lhs);
}
IdArray Div(dgl_id_t lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0]);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = lhs / rhs_data[i];
}
return ret;
}
IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L);
const dgl_id_t* arr1_data = static_cast<dgl_id_t*>(arr1->data);
const dgl_id_t* arr2_data = static_cast<dgl_id_t*>(arr2->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < L; ++i) {
ret_data[i] = arr1_data[i];
ret_data[i + L] = arr2_data[i];
}
return ret;
}
CSRMatrix SliceRows(const CSRMatrix& csr, int64_t start, int64_t end) {
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data);
const dgl_id_t* data = static_cast<dgl_id_t*>(csr.data->data);
const int64_t num_rows = end - start;
const int64_t nnz = indptr[end] - indptr[start];
CSRMatrix ret;
ret.indptr = NewIdArray(num_rows + 1);
ret.indices = NewIdArray(nnz);
ret.data = NewIdArray(nnz);
dgl_id_t* r_indptr = static_cast<dgl_id_t*>(ret.indptr->data);
dgl_id_t* r_indices = static_cast<dgl_id_t*>(ret.indices->data);
dgl_id_t* r_data = static_cast<dgl_id_t*>(ret.data->data);
for (int64_t i = start; i < end + 1; ++i) {
r_indptr[i - start] = indptr[i] - indptr[start];
}
std::copy(indices + indptr[start], indices + indptr[end], r_indices);
std::copy(data + indptr[start], data + indptr[end], r_data);
return ret;
}
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/arith.h
* \brief Arithmetic functors
*/
#ifndef DGL_ARRAY_ARITH_H_
#define DGL_ARRAY_ARITH_H_
namespace dgl {
namespace aten {
namespace arith {
struct Add {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 + t2;
}
};
struct Sub {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 - t2;
}
};
struct Mul {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 * t2;
}
};
struct Div {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 / t2;
}
};
struct LT {
template <typename T>
inline static bool Call(const T& t1, const T& t2) {
return t1 < t2;
}
};
} // namespace arith
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_ARITH_H_
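
Editor's note: these functors are consumed by the generic element-wise templates later in this change (see BinaryElewise in array/cpu/array_op_impl.cc). A standalone toy illustration of the pattern, not part of the diff:

// Toy sketch of the functor-dispatch pattern: the loop is generic, the
// operation is supplied as a type with a static Call().
#include <cstddef>
#include <vector>

template <typename Op, typename T>
std::vector<T> ElewiseToy(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> out(a.size());
  for (size_t i = 0; i < a.size(); ++i)
    out[i] = Op::Call(a[i], b[i]);  // e.g. Op = dgl::aten::arith::Add
  return out;
}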
/*!
* Copyright (c) 2019 by Contributors
* \file array/array.cc
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include "../c_api_common.h"
#include "./array_op.h"
#include "./arith.h"
#include "./common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx);
}
IdArray Clone(IdArray arr) {
IdArray ret = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);
ret.CopyFrom(arr);
return ret;
}
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH(ctx.device_type, XPU, {
if (nbits == 32) {
ret = impl::Range<XPU, int32_t>(low, high, ctx);
} else if (nbits == 64) {
ret = impl::Range<XPU, int64_t>(low, high, ctx);
} else {
LOG(FATAL) << "Only int32 or int64 is supported.";
}
});
return ret;
}
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH(ctx.device_type, XPU, {
if (nbits == 32) {
ret = impl::Full<XPU, int32_t>(val, length, ctx);
} else if (nbits == 64) {
ret = impl::Full<XPU, int64_t>(val, length, ctx);
} else {
LOG(FATAL) << "Only int32 or int64 is supported.";
}
});
return ret;
}
IdArray AsNumBits(IdArray arr, uint8_t bits) {
IdArray ret;
ATEN_XPU_SWITCH(arr->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
ret = impl::AsNumBits<XPU, IdType>(arr, bits);
});
});
return ret;
}
IdArray Add(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
});
return ret;
}
IdArray Sub(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
});
return ret;
}
IdArray Div(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
});
});
return ret;
}
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
});
});
return ret;
}
IdArray Div(IdArray lhs, dgl_id_t rhs) {
IdArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
IdArray Add(dgl_id_t lhs, IdArray rhs) {
return Add(rhs, lhs);
}
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
});
});
return ret;
}
IdArray Mul(dgl_id_t lhs, IdArray rhs) {
return Mul(rhs, lhs);
}
IdArray Div(dgl_id_t lhs, IdArray rhs) {
IdArray ret;
ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
});
});
return ret;
}
BoolArray LT(IdArray lhs, dgl_id_t rhs) {
BoolArray ret;
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::BinaryElewise<XPU, IdType, arith::LT>(lhs, rhs);
});
});
return ret;
}
IdArray HStack(IdArray lhs, IdArray rhs) {
IdArray ret;
CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
ret = impl::HStack<XPU, IdType>(lhs, rhs);
});
});
return ret;
}
IdArray IndexSelect(IdArray array, IdArray index) {
IdArray ret;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::IndexSelect<XPU, IdType>(array, index);
});
});
return ret;
}
int64_t IndexSelect(IdArray array, int64_t index) {
int64_t ret = 0;
ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
ret = impl::IndexSelect<XPU, IdType>(array, index);
});
});
return ret;
}
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray ret;
ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
ret = impl::Relabel_<XPU, IdType>(arrays);
});
});
return ret;
}
///////////////////////// CSR routines //////////////////////////
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
bool ret = false;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
}
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
NDArray ret;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
});
return ret;
}
bool CSRHasDuplicate(CSRMatrix csr) {
bool ret = false;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
});
return ret;
}
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
int64_t ret = 0;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
NDArray ret;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
});
return ret;
}
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRGetRowData<XPU, IdType, DType>(csr, row);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRGetData<XPU, IdType, DType>(csr, row, col);
});
return ret;
}
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
NDArray ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols);
});
return ret;
}
std::vector<NDArray> CSRGetDataAndIndices(
CSRMatrix csr, NDArray rows, NDArray cols) {
std::vector<NDArray> ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRGetDataAndIndices<XPU, IdType, DType>(csr, rows, cols);
});
return ret;
}
CSRMatrix CSRTranspose(CSRMatrix csr) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRTranspose<XPU, IdType, DType>(csr);
});
return ret;
}
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
COOMatrix ret;
if (data_as_order) {
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
});
});
} else {
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
ret = impl::CSRToCOO<XPU, IdType>(csr);
});
});
}
return ret;
}
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRSliceRows<XPU, IdType, DType>(csr, start, end);
});
return ret;
}
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRSliceRows<XPU, IdType, DType>(csr, rows);
});
return ret;
}
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
CSRMatrix ret;
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
ret = impl::CSRSliceMatrix<XPU, IdType, DType>(csr, rows, cols);
});
return ret;
}
///////////////////////// COO routines //////////////////////////
bool COOHasDuplicate(COOMatrix coo) {
bool ret = false;
ATEN_COO_IDX_SWITCH(coo, XPU, IdType, {
ret = impl::COOHasDuplicate<XPU, IdType>(coo);
});
return ret;
}
CSRMatrix COOToCSR(COOMatrix coo) {
CSRMatrix ret;
ATEN_COO_SWITCH(coo, XPU, IdType, DType, {
ret = impl::COOToCSR<XPU, IdType, DType>(coo);
});
return ret;
}
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/array_op.h
* \brief Array operator templates
*/
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_
#include <dgl/array.h>
#include <vector>
namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx);
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx);
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2);
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index);
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index);
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr);
// Convert CSR to COO using data array as order
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix COOToCSR(COOMatrix coo);
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_ARRAY_OP_H_
/*!
* Copyright (c) 2019 by Contributors
* \file array/common.h
* \brief Array operator common utilities
*/
#ifndef DGL_ARRAY_COMMON_H_
#define DGL_ARRAY_COMMON_H_
namespace dgl {
namespace aten {
#define ATEN_XPU_SWITCH(val, XPU, ...) \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
}
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t IdType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "ID can Only be int32 or int64"; \
}
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) \
if ((val).code == kDLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
}
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, DType, ...) \
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
typedef IdType DType; \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type
#define ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, ...) \
ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_COO_SWITCH(coo, XPU, IdType, DType, ...) \
ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
typedef IdType DType; \
{__VA_ARGS__} \
}); \
});
// Macro to dispatch according to device context and index type
#define ATEN_COO_IDX_SWITCH(coo, XPU, IdType, ...) \
ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
{__VA_ARGS__} \
}); \
});
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_COMMON_H_
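
Editor's note: as a worked illustration of how the dispatchers in array/array.cc above use these macros (e.g. the AsNumBits wrapper), here is the conceptual result of expanding ATEN_XPU_SWITCH with a nested ATEN_ID_TYPE_SWITCH; whitespace and the extra braces from {__VA_ARGS__} are simplified, and this text is not part of the diff:

// Conceptual expansion of
//   ATEN_XPU_SWITCH(arr->ctx.device_type, XPU, {
//     ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
//       ret = impl::AsNumBits<XPU, IdType>(arr, bits);
//     });
//   });
if (arr->ctx.device_type == kDLCPU) {
  constexpr auto XPU = kDLCPU;
  CHECK_EQ(arr->dtype.code, kDLInt) << "ID must be integer type";
  if (arr->dtype.bits == 32) {
    typedef int32_t IdType;
    ret = impl::AsNumBits<XPU, IdType>(arr, bits);
  } else if (arr->dtype.bits == 64) {
    typedef int64_t IdType;
    ret = impl::AsNumBits<XPU, IdType>(arr, bits);
  } else {
    LOG(FATAL) << "ID can Only be int32 or int64";
  }
} else {
  LOG(FATAL) << "Device type: " << arr->ctx.device_type << " is not supported.";
}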
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/array_op_impl.cc
* \brief Array operator CPU implementation
*/
#include <dgl/array.h>
#include <numeric>
#include "../arith.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
///////////////////////////// AsNumBits /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
if (sizeof(IdType) * 8 == bits) {
return arr;
}
const int64_t len = arr->shape[0];
IdArray ret = NewIdArray(len, arr->ctx, bits);
const IdType* arr_data = static_cast<IdType*>(arr->data);
if (bits == 32) {
int32_t* ret_data = static_cast<int32_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
} else {
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
}
return ret;
}
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits);
///////////////////////////// BinaryElewise /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs_data[i], rhs_data[i]);
}
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs_data[i], rhs);
}
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = Op::Call(lhs, rhs_data[i]);
}
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
///////////////////////////// HStack /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L, arr1->ctx, arr1->dtype.bits);
const IdType* arr1_data = static_cast<IdType*>(arr1->data);
const IdType* arr2_data = static_cast<IdType*>(arr2->data);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < L; ++i) {
ret_data[i] = arr1_data[i];
ret_data[i + L] = arr2_data[i];
}
return ret;
}
template IdArray HStack<kDLCPU, int32_t>(IdArray arr1, IdArray arr2);
template IdArray HStack<kDLCPU, int64_t>(IdArray arr1, IdArray arr2);
///////////////////////////// Full /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx) {
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
IdType* ret_data = static_cast<IdType*>(ret->data);
std::fill(ret_data, ret_data + length, val);
return ret;
}
template IdArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
///////////////////////////// Range /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
CHECK(high >= low) << "high must be greater than or equal to low";
IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
IdType* ret_data = static_cast<IdType*>(ret->data);
std::iota(ret_data, ret_data + high - low, low);
return ret;
}
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// IndexSelect /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index) {
const IdType* array_data = static_cast<IdType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
const int64_t len = index->shape[0];
IdArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
IdType* ret_data = static_cast<IdType*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
ret_data[i] = array_data[idx_data[i]];
}
return ret;
}
template IdArray IndexSelect<kDLCPU, int32_t>(IdArray, IdArray);
template IdArray IndexSelect<kDLCPU, int64_t>(IdArray, IdArray);
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index) {
const IdType* data = static_cast<IdType*>(array->data);
return data[index];
}
template int64_t IndexSelect<kDLCPU, int32_t>(IdArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(IdArray array, int64_t index);
///////////////////////////// Relabel_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
// build map & relabel
IdType newid = 0;
std::unordered_map<IdType, IdType> oldv2newv;
for (IdArray arr : arrays) {
for (int64_t i = 0; i < arr->shape[0]; ++i) {
const IdType id = static_cast<IdType*>(arr->data)[i];
if (!oldv2newv.count(id)) {
oldv2newv[id] = newid++;
}
static_cast<IdType*>(arr->data)[i] = oldv2newv[id];
}
}
// map array
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first;
}
return maparr;
}
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
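// Worked example (illustrative): relabeling the single array [10, 20, 10]
// rewrites it in place to [0, 1, 0] and returns maparr = [10, 20], so
// maparr[new_id] recovers the original id.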
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/spmat_op_impl.cc
* \brief Sparse matrix operator CPU implementation
*/
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
namespace {
/*!
 * \brief A hashmap that maps each id in the given array to a new id starting from zero.
*/
template <typename IdType>
class IdHashMap {
public:
// Construct the hashmap using the given id arrays.
// The id array could contain duplicates.
explicit IdHashMap(IdArray ids): filter_(kFilterSize, false) {
const IdType* ids_data = static_cast<IdType*>(ids->data);
const int64_t len = ids->shape[0];
IdType newid = 0;
for (int64_t i = 0; i < len; ++i) {
const IdType id = ids_data[i];
if (!Contains(id)) {
oldv2newv_[id] = newid++;
filter_[id & kFilterMask] = true;
}
}
}
// Return true if the given id is contained in this hashmap.
bool Contains(IdType id) const {
return filter_[id & kFilterMask] && oldv2newv_.count(id);
}
// Return the new id of the given id. If the given id is not contained
// in the hash map, returns the default_val instead.
IdType Map(IdType id, IdType default_val) const {
if (filter_[id & kFilterMask]) {
auto it = oldv2newv_.find(id);
return (it == oldv2newv_.end()) ? default_val : it->second;
} else {
return default_val;
}
}
private:
static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1;
// This bitmap acts as a bloom filter that rules out many lookups early.
// The hash table lookup is comparatively slow, so the filter significantly
// speeds up queries for ids that were never inserted.
std::vector<bool> filter_;
// The hashmap from old vid to new vid
std::unordered_map<IdType, IdType> oldv2newv_;
};
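// Usage sketch (illustrative; the ids and queried values are made up):
//
//   IdHashMap<int64_t> map(ids);       // ids: IdArray of (possibly duplicated) old ids
//   if (map.Contains(42)) { ... }      // cheap bitmap test first
//   int64_t newid = map.Map(42, -1);   // -1 returned if 42 was never inserted
//
// kFilterMask keeps the low 24 bits, so the bitmap has 2^24 entries. A
// cleared bit guarantees the id was never inserted; a set bit still needs
// the exact unordered_map lookup because different ids can share low bits.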
struct PairHash {
template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
template <typename DType>
inline runtime::NDArray VecToNDArray(const std::vector<DType>& vec,
DLDataType dtype, DLContext ctx) {
const int64_t len = vec.size();
NDArray ret_arr = NDArray::Empty({len}, dtype, ctx);
DType* ptr = static_cast<DType*>(ret_arr->data);
std::copy(vec.begin(), vec.end(), ptr);
return ret_arr;
}
inline bool CSRHasData(CSRMatrix csr) {
return csr.data.defined();
}
inline bool COOHasData(COOMatrix coo) {
return coo.data.defined();
}
} // namespace
///////////////////////////// CSRIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) {
if (indices_data[i] == col) {
return true;
}
}
return false;
}
template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const auto rstlen = std::max(rowlen, collen);
NDArray rst = NDArray::Empty({rstlen}, row->dtype, row->ctx);
IdType* rst_data = static_cast<IdType*>(rst->data);
const IdType* row_data = static_cast<IdType*>(row->data);
const IdType* col_data = static_cast<IdType*>(col->data);
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
*(rst_data++) = CSRIsNonZero<XPU, IdType>(csr, row_data[i], col_data[j])? 1 : 0;
}
return rst;
}
template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
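// Broadcasting example (illustrative): row = [2], col = [0, 1, 2] gives
// row_stride = 0 and col_stride = 1, so the pairs tested are (2,0), (2,1),
// (2,2) and the result has length max(rowlen, collen) = 3. The same stride
// trick is reused by CSRGetData and CSRGetDataAndIndices below.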
///////////////////////////// CSRHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
for (IdType src = 0; src < csr.num_rows; ++src) {
std::unordered_set<IdType> hashmap;
for (IdType eid = indptr_data[src]; eid < indptr_data[src+1]; ++eid) {
const IdType dst = indices_data[eid];
if (hashmap.count(dst)) {
return true;
} else {
hashmap.insert(dst);
}
}
}
return false;
}
template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
return indptr_data[row + 1] - indptr_data[row];
}
template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
const auto len = rows->shape[0];
const IdType* vid_data = static_cast<IdType*>(rows->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
IdType* rst_data = static_cast<IdType*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
rst_data[i] = indptr_data[vid + 1] - indptr_data[vid];
}
return rst;
}
template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const int64_t offset = indptr_data[row] * sizeof(IdType);
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}
template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
CHECK(CSRHasData(csr)) << "missing data array";
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const int64_t offset = indptr_data[row] * sizeof(DType);
return csr.data.CreateView({len}, csr.data->dtype, offset);
}
template NDArray CSRGetRowData<kDLCPU, int32_t, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(CSRHasData(csr)) << "missing data array";
// TODO(minjie): use more efficient binary search when the column indices are sorted
CHECK(row >= 0 && row < csr.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
std::vector<DType> ret_vec;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data);
for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) {
if (indices_data[i] == col) {
ret_vec.push_back(data[i]);
}
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(CSRMatrix, int64_t, int64_t);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK(CSRHasData(csr)) << "missing data array";
// TODO(minjie): more efficient implementation for sorted column index
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const IdType* row_data = static_cast<IdType*>(rows->data);
const IdType* col_data = static_cast<IdType*>(cols->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data);
std::vector<DType> ret_vec;
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
for (IdType p = indptr_data[row_id]; p < indptr_data[row_id+1]; ++p) {
if (indices_data[p] == col_id) {
ret_vec.push_back(data[p]);
}
}
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
}
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(CSRMatrix csr, NDArray rows, NDArray cols);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK(CSRHasData(csr)) << "missing data array";
// TODO(minjie): more efficient implementation for matrix without duplicate entries
// TODO(minjie): more efficient implementation for sorted column index
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const IdType* row_data = static_cast<IdType*>(rows->data);
const IdType* col_data = static_cast<IdType*>(cols->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data);
std::vector<IdType> ret_rows, ret_cols;
std::vector<DType> ret_data;
for (int64_t i = 0, j = 0; i < rowlen && j < collen; i += row_stride, j += col_stride) {
const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
for (IdType p = indptr_data[row_id]; p < indptr_data[row_id+1]; ++p) {
if (indices_data[p] == col_id) {
ret_rows.push_back(row_id);
ret_cols.push_back(col_id);
ret_data.push_back(data[p]);
}
}
}
return {VecToIdArray(ret_rows, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToIdArray(ret_cols, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToNDArray(ret_data, csr.data->dtype, csr.data->ctx)};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRTranspose /////////////////////////////
// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
CHECK(CSRHasData(csr)) << "missing data array is currently not allowed in CSRTranspose.";
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
const int64_t nnz = csr.indices->shape[0];
const IdType* Ap = static_cast<IdType*>(csr.indptr->data);
const IdType* Aj = static_cast<IdType*>(csr.indices->data);
const DType* Ax = static_cast<DType*>(csr.data->data);
NDArray ret_indptr = NDArray::Empty({M + 1}, csr.indptr->dtype, csr.indptr->ctx);
NDArray ret_indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray ret_data = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);
IdType* Bp = static_cast<IdType*>(ret_indptr->data);
IdType* Bi = static_cast<IdType*>(ret_indices->data);
DType* Bx = static_cast<DType*>(ret_data->data);
std::fill(Bp, Bp + M, 0);
for (int64_t j = 0; j < nnz; ++j) {
Bp[Aj[j]]++;
}
// cumsum
for (int64_t i = 0, cumsum = 0; i < M; ++i) {
const IdType temp = Bp[i];
Bp[i] = cumsum;
cumsum += temp;
}
Bp[M] = nnz;
for (int64_t i = 0; i < N; ++i) {
for (IdType j = Ap[i]; j < Ap[i+1]; ++j) {
const IdType dst = Aj[j];
Bi[Bp[dst]] = i;
Bx[Bp[dst]] = Ax[j];
Bp[dst]++;
}
}
// correct the indptr
for (int64_t i = 0, last = 0; i <= M; ++i) {
IdType temp = Bp[i];
Bp[i] = last;
last = temp;
}
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
}
template CSRMatrix CSRTranspose<kDLCPU, int32_t, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t, int64_t>(CSRMatrix csr);
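// Worked example (illustrative) for a 2x3 matrix with Ap = [0, 2, 3] and
// Aj = [0, 2, 0] (nnz = 3):
//   count columns          -> Bp = [2, 0, 1, _]
//   exclusive prefix sum   -> Bp = [0, 2, 2, 3]
//   scatter source rows    -> Bi = [0, 1, 0], Bp drifts to [2, 2, 3, 3]
//   shift right by one     -> Bp = [0, 2, 2, 3], the final indptr of the
//                             3x2 transpose.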
///////////////////////////// CSRToCOO /////////////////////////////
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
for (IdType i = 0; i < csr.indptr->shape[0] - 1; ++i) {
std::fill(ret_row_data + indptr_data[i],
ret_row_data + indptr_data[i + 1],
i);
}
return COOMatrix{csr.num_rows, csr.num_cols, ret_row, csr.indices, csr.data};
}
template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);
// complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
CHECK(CSRHasData(csr)) << "missing data array.";
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
// the data array is assumed to have the same dtype as the indices array
const IdType* data = static_cast<IdType*>(csr.data->data);
NDArray ret_row = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
NDArray ret_col = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
IdType* ret_row_data = static_cast<IdType*>(ret_row->data);
IdType* ret_col_data = static_cast<IdType*>(ret_col->data);
// scatter using the indices in the data array
for (IdType row = 0; row < N; ++row) {
for (IdType j = indptr_data[row]; j < indptr_data[row + 1]; ++j) {
const IdType col = indices_data[j];
ret_row_data[data[j]] = row;
ret_col_data[data[j]] = col;
}
}
COOMatrix coo;
coo.num_rows = N;
coo.num_cols = M;
coo.row = ret_row;
coo.col = ret_col;
// no data array
return coo;
}
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);
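// Example (illustrative): with csr.data = [2, 0, 1], the j-th CSR entry is
// written to position data[j] of the output, so ret_row/ret_col come back
// ordered by the original edge id rather than by (row, col) position.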
///////////////////////////// CSRSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
CHECK(CSRHasData(csr)) << "missing data array.";
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const int64_t num_rows = end - start;
const int64_t nnz = indptr[end] - indptr[start];
CSRMatrix ret;
ret.num_rows = num_rows;
ret.num_cols = csr.num_cols;
ret.indptr = NDArray::Empty({num_rows + 1}, csr.indptr->dtype, csr.indptr->ctx);
IdType* r_indptr = static_cast<IdType*>(ret.indptr->data);
for (int64_t i = start; i < end + 1; ++i) {
r_indptr[i - start] = indptr[i] - indptr[start];
}
// indices and data can be view arrays
ret.indices = csr.indices.CreateView({nnz}, csr.indices->dtype, indptr[start] * sizeof(IdType));
ret.data = csr.data.CreateView({nnz}, csr.data->dtype, indptr[start] * sizeof(DType));
return ret;
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t, int64_t>(CSRMatrix, int64_t, int64_t);
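// Example (illustrative): slicing rows [1, 3) of a CSR with
// indptr = [0, 2, 5, 6] yields indptr = [0, 3, 4]; indices and data are
// returned as zero-copy views starting at byte offset
// indptr[start] * sizeof(element).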
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK(CSRHasData(csr)) << "missing data array.";
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data);
const auto len = rows->shape[0];
const IdType* rows_data = static_cast<IdType*>(rows->data);
int64_t nnz = 0;
for (int64_t i = 0; i < len; ++i) {
IdType vid = rows_data[i];
nnz += impl::CSRGetRowNNZ<XPU, IdType>(csr, vid);
}
CSRMatrix ret;
ret.num_rows = len;
ret.num_cols = csr.num_cols;
ret.indptr = NDArray::Empty({len + 1}, csr.indptr->dtype, csr.indptr->ctx);
ret.indices = NDArray::Empty({nnz}, csr.indices->dtype, csr.indices->ctx);
ret.data = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);
IdType* ret_indptr_data = static_cast<IdType*>(ret.indptr->data);
IdType* ret_indices_data = static_cast<IdType*>(ret.indices->data);
DType* ret_data = static_cast<DType*>(ret.data->data);
ret_indptr_data[0] = 0;
for (int64_t i = 0; i < len; ++i) {
const IdType rid = rows_data[i];
// note: zero is allowed
ret_indptr_data[i + 1] = ret_indptr_data[i] + indptr_data[rid + 1] - indptr_data[rid];
std::copy(indices_data + indptr_data[rid], indices_data + indptr_data[rid + 1],
ret_indices_data + ret_indptr_data[i]);
std::copy(data + indptr_data[rid], data + indptr_data[rid + 1],
ret_data + ret_indptr_data[i]);
}
return ret;
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
CHECK(CSRHasData(csr)) << "missing data array.";
IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0];
const int64_t new_ncols = cols->shape[0];
const IdType* rows_data = static_cast<IdType*>(rows->data);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data);
std::vector<IdType> sub_indptr, sub_indices;
std::vector<DType> sub_data;
sub_indptr.resize(new_nrows + 1, 0);
const IdType kInvalidId = new_ncols + 1;
for (int64_t i = 0; i < new_nrows; ++i) {
// NOTE: newi == i
const IdType oldi = rows_data[i];
CHECK(oldi >= 0 && oldi < csr.num_rows) << "Invalid row index: " << oldi;
for (IdType p = indptr_data[oldi]; p < indptr_data[oldi + 1]; ++p) {
const IdType oldj = indices_data[p];
const IdType newj = hashmap.Map(oldj, kInvalidId);
if (newj != kInvalidId) {
++sub_indptr[i];
sub_indices.push_back(newj);
sub_data.push_back(data[p]);
}
}
}
// cumsum sub_indptr
for (int64_t i = 0, cumsum = 0; i < new_nrows; ++i) {
const IdType temp = sub_indptr[i];
sub_indptr[i] = cumsum;
cumsum += temp;
}
sub_indptr[new_nrows] = sub_indices.size();
const int64_t nnz = sub_data.size();
NDArray sub_data_arr = NDArray::Empty({nnz}, csr.data->dtype, csr.data->ctx);
DType* ptr = static_cast<DType*>(sub_data_arr->data);
std::copy(sub_data.begin(), sub_data.end(), ptr);
return CSRMatrix{new_nrows, new_ncols,
VecToIdArray(sub_indptr, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToIdArray(sub_indices, csr.indptr->dtype.bits, csr.indptr->ctx),
sub_data_arr};
}
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// COOHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
const IdType* src_data = static_cast<IdType*>(coo.row->data);
const IdType* dst_data = static_cast<IdType*>(coo.col->data);
const auto nnz = coo.row->shape[0];
for (IdType eid = 0; eid < nnz; ++eid) {
const auto& p = std::make_pair(src_data[eid], dst_data[eid]);
if (hashmap.count(p)) {
return true;
} else {
hashmap.insert(p);
}
}
return false;
}
template bool COOHasDuplicate<kDLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDLCPU, int64_t>(COOMatrix coo);
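// Note (illustrative): PairHash XORs the two element hashes, so (u, v) and
// (v, u) land in the same bucket. Correctness is unaffected because the
// unordered_set still compares pairs for equality; only the bucket
// distribution degrades for symmetric edge lists.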
///////////////////////////// COOToCSR /////////////////////////////
// complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix COOToCSR(COOMatrix coo) {
const int64_t N = coo.num_rows;
const int64_t NNZ = coo.row->shape[0];
const IdType* row_data = static_cast<IdType*>(coo.row->data);
const IdType* col_data = static_cast<IdType*>(coo.col->data);
NDArray ret_indptr = NDArray::Empty({N + 1}, coo.row->dtype, coo.row->ctx);
NDArray ret_indices = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
NDArray ret_data;
if (COOHasData(coo)) {
ret_data = NDArray::Empty({NNZ}, coo.data->dtype, coo.data->ctx);
} else {
// If the input COO has no data array, the returned data array stores the
// shuffle order, i.e. the original entry index of each nonzero.
ret_data = NDArray::Empty({NNZ}, coo.row->dtype, coo.row->ctx);
}
IdType* Bp = static_cast<IdType*>(ret_indptr->data);
IdType* Bi = static_cast<IdType*>(ret_indices->data);
std::fill(Bp, Bp + N, 0);
for (int64_t i = 0; i < NNZ; ++i) {
Bp[row_data[i]]++;
}
// cumsum
for (int64_t i = 0, cumsum = 0; i < N; ++i) {
const IdType temp = Bp[i];
Bp[i] = cumsum;
cumsum += temp;
}
Bp[N] = NNZ;
for (int64_t i = 0; i < NNZ; ++i) {
const IdType r = row_data[i];
Bi[Bp[r]] = col_data[i];
if (COOHasData(coo)) {
const DType* data = static_cast<DType*>(coo.data->data);
DType* Bx = static_cast<DType*>(ret_data->data);
Bx[Bp[r]] = data[i];
} else {
IdType* Bx = static_cast<IdType*>(ret_data->data);
Bx[Bp[r]] = i;
}
Bp[r]++;
}
// correct the indptr
for (int64_t i = 0, last = 0; i <= N; ++i) {
IdType temp = Bp[i];
Bp[i] = last;
last = temp;
}
return CSRMatrix{coo.num_rows, coo.num_cols, ret_indptr, ret_indices, ret_data};
}
template CSRMatrix COOToCSR<kDLCPU, int32_t, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLCPU, int64_t, int64_t>(COOMatrix coo);
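// Minimal standalone sketch of the counting-sort construction used by
// COOToCSR above, on plain std::vector input and without a data array.
// SimpleCOOToCSR is an illustrative helper only and is not used by the
// library code; it assumes all row ids lie within [0, num_rows).
inline std::pair<std::vector<int64_t>, std::vector<int64_t>> SimpleCOOToCSR(
    int64_t num_rows,
    const std::vector<int64_t>& row,
    const std::vector<int64_t>& col) {
  const int64_t nnz = static_cast<int64_t>(row.size());
  std::vector<int64_t> indptr(num_rows + 1, 0), indices(nnz);
  for (int64_t e = 0; e < nnz; ++e)        // count entries per row
    ++indptr[row[e] + 1];
  for (int64_t i = 0; i < num_rows; ++i)   // inclusive prefix sum -> row offsets
    indptr[i + 1] += indptr[i];
  std::vector<int64_t> pos(indptr.begin(), indptr.end() - 1);  // write cursors
  for (int64_t e = 0; e < nnz; ++e)        // scatter column indices
    indices[pos[row[e]]++] = col[e];
  return {indptr, indices};
}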
} // namespace impl
} // namespace aten
} // namespace dgl
...@@ -12,6 +12,16 @@ ...@@ -12,6 +12,16 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLDataType& ty) {
return os << "code=" << ty.code << ",bits=" << ty.bits << "lanes=" << ty.lanes;
}
/*! \brief Check whether two device contexts are the same.*/ /*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) { inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id; return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
...@@ -19,7 +29,7 @@ inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) { ...@@ -19,7 +29,7 @@ inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
/*! \brief Output the string representation of device context.*/ /*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) { inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
return os << "" << ctx.device_type << ":" << ctx.device_id << ""; return os << ctx.device_type << ":" << ctx.device_id;
} }
namespace dgl { namespace dgl {
...@@ -45,8 +55,7 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc( ...@@ -45,8 +55,7 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
/*!\brief Return whether the array is a valid 1D int array*/ /*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1 return arr->ndim == 1 && arr->dtype.code == kDLInt;
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64;
} }
/*! /*!
......
...@@ -138,9 +138,9 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs ...@@ -138,9 +138,9 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs
num_nodes += gr->NumVertices(); num_nodes += gr->NumVertices();
num_edges += gr->NumEdges(); num_edges += gr->NumEdges();
} }
IdArray indptr_arr = NewIdArray(num_nodes + 1); IdArray indptr_arr = aten::NewIdArray(num_nodes + 1);
IdArray indices_arr = NewIdArray(num_edges); IdArray indices_arr = aten::NewIdArray(num_edges);
IdArray edge_ids_arr = NewIdArray(num_edges); IdArray edge_ids_arr = aten::NewIdArray(num_edges);
dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data); dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data); dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data); dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
...@@ -207,9 +207,9 @@ std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGra ...@@ -207,9 +207,9 @@ std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGra
const int64_t end_pos = cumsum[i + 1]; const int64_t end_pos = cumsum[i + 1];
const int64_t g_num_nodes = sizes_data[i]; const int64_t g_num_nodes = sizes_data[i];
const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos]; const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos];
IdArray indptr_arr = NewIdArray(g_num_nodes + 1); IdArray indptr_arr = aten::NewIdArray(g_num_nodes + 1);
IdArray indices_arr = NewIdArray(g_num_edges); IdArray indices_arr = aten::NewIdArray(g_num_edges);
IdArray edge_ids_arr = NewIdArray(g_num_edges); IdArray edge_ids_arr = aten::NewIdArray(g_num_edges);
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data); dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data); dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data); dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
...@@ -329,13 +329,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) { ...@@ -329,13 +329,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
for (dgl_id_t v = u; v < g->NumVertices(); ++v) { for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]); const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
if (new_n_e > 0) { if (new_n_e > 0) {
IdArray us = NewIdArray(new_n_e); IdArray us = aten::NewIdArray(new_n_e);
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data); dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
std::fill(us_data, us_data + new_n_e, u); std::fill(us_data, us_data + new_n_e, u);
if (u == v) { if (u == v) {
bg.AddEdges(us, us); bg.AddEdges(us, us);
} else { } else {
IdArray vs = NewIdArray(new_n_e); IdArray vs = aten::NewIdArray(new_n_e);
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data); dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
std::fill(vs_data, vs_data + new_n_e, v); std::fill(vs_data, vs_data + new_n_e, v);
bg.AddEdges(us, vs); bg.AddEdges(us, vs);
...@@ -380,8 +380,8 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) { ...@@ -380,8 +380,8 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
} }
} }
IdArray srcs_array = VecToIdArray(srcs); IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = VecToIdArray(dsts); IdArray dsts_array = aten::VecToIdArray(dsts);
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph())); COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()));
return ImmutableGraph(coo); return ImmutableGraph(coo);
} }
......
...@@ -14,59 +14,6 @@ ...@@ -14,59 +14,6 @@
namespace dgl { namespace dgl {
namespace { namespace {
/*!
* \brief A hashmap that maps each ids in the given array to new ids starting from zero.
*/
class IdHashMap {
public:
// Construct the hashmap using the given id arrays.
// The id array could contain duplicates.
explicit IdHashMap(IdArray ids): filter_(kFilterSize, false) {
const dgl_id_t* ids_data = static_cast<dgl_id_t*>(ids->data);
const int64_t len = ids->shape[0];
dgl_id_t newid = 0;
for (int64_t i = 0; i < len; ++i) {
const dgl_id_t id = ids_data[i];
if (!Contains(id)) {
oldv2newv_[id] = newid++;
filter_[id & kFilterMask] = true;
}
}
}
// Return true if the given id is contained in this hashmap.
bool Contains(dgl_id_t id) const {
return filter_[id & kFilterMask] && oldv2newv_.count(id);
}
// Return the new id of the given id. If the given id is not contained
// in the hash map, returns the default_val instead.
dgl_id_t Map(dgl_id_t id, dgl_id_t default_val) const {
if (filter_[id & kFilterMask]) {
auto it = oldv2newv_.find(id);
return (it == oldv2newv_.end()) ? default_val : it->second;
} else {
return default_val;
}
}
private:
static constexpr int32_t kFilterMask = 0xFFFFFF;
static constexpr int32_t kFilterSize = kFilterMask + 1;
// This bitmap is used as a bloom filter to remove some lookups.
// Hashtable is very slow. Using bloom filter can significantly speed up lookups.
std::vector<bool> filter_;
// The hashmap from old vid to new vid
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv_;
};
struct PairHash {
template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32 #ifndef _WIN32
...@@ -97,26 +44,30 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( ...@@ -97,26 +44,30 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph) CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
: is_multigraph_(is_multigraph) { : is_multigraph_(is_multigraph) {
indptr_ = NewIdArray(num_vertices + 1); CHECK(!(num_vertices == 0 && num_edges != 0));
indices_ = NewIdArray(num_edges); adj_ = aten::CSRMatrix{num_vertices, num_vertices,
edge_ids_ = NewIdArray(num_edges); aten::NewIdArray(num_vertices + 1),
aten::NewIdArray(num_edges),
aten::NewIdArray(num_edges)};
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
: indptr_(indptr), indices_(indices), edge_ids_(edge_ids) {
CHECK(IsValidIdArray(indptr)); CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: indptr_(indptr), indices_(indices), edge_ids_(edge_ids), : is_multigraph_(is_multigraph) {
is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(indptr)); CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
...@@ -127,12 +78,14 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -127,12 +78,14 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, true); shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays // copy the given data into the shared memory arrays
indptr_.CopyFrom(indptr); adj_.indptr.CopyFrom(indptr);
indices_.CopyFrom(indices); adj_.indices.CopyFrom(indices);
edge_ids_.CopyFrom(edge_ids); adj_.data.CopyFrom(edge_ids);
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
...@@ -144,328 +97,118 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, ...@@ -144,328 +97,118 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, true); shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays // copy the given data into the shared memory arrays
indptr_.CopyFrom(indptr); adj_.indptr.CopyFrom(indptr);
indices_.CopyFrom(indices); adj_.indices.CopyFrom(indices);
edge_ids_.CopyFrom(edge_ids); adj_.data.CopyFrom(edge_ids);
} }
CSR::CSR(const std::string &shared_mem_name, CSR::CSR(const std::string &shared_mem_name,
int64_t num_verts, int64_t num_edges, bool is_multigraph) int64_t num_verts, int64_t num_edges, bool is_multigraph)
: is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) { : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( CHECK(!(num_verts == 0 && num_edges != 0));
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, false); shared_mem_name, num_verts, num_edges, false);
} }
bool CSR::IsMultigraph() const { bool CSR::IsMultigraph() const {
// The lambda will be called the first time to initialize the is_multigraph flag. // The lambda will be called the first time to initialize the is_multigraph flag.
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () { return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); return aten::CSRHasDuplicate(adj_);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
for (dgl_id_t src = 0; src < NumVertices(); ++src) {
std::unordered_set<dgl_id_t> hashmap;
for (dgl_id_t eid = indptr_data[src]; eid < indptr_data[src+1]; ++eid) {
const dgl_id_t dst = indices_data[eid];
if (hashmap.count(dst)) {
return true;
} else {
hashmap.insert(dst);
}
}
}
return false;
}); });
} }
CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const { CSR::EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data); IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
const dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(edge_ids_->data); IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
const dgl_id_t off = indptr_data[vid]; return CSR::EdgeArray{ret_src, ret_dst, ret_eid};
const int64_t len = OutDegree(vid);
IdArray src = NewIdArray(len);
IdArray dst = NewIdArray(len);
IdArray eid = NewIdArray(len);
dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);
dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);
dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);
std::fill(src_data, src_data + len, vid);
std::copy(indices_data + off, indices_data + off + len, dst_data);
std::copy(edge_ids_data + off, edge_ids_data + off + len, eid_data);
return CSR::EdgeArray{src, dst, eid};
} }
CSR::EdgeArray CSR::OutEdges(IdArray vids) const { CSR::EdgeArray CSR::OutEdges(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); auto csrsubmat = aten::CSRSliceRows(adj_, vids);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data); auto coosubmat = aten::CSRToCOO(csrsubmat, false);
const dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(edge_ids_->data); // Note that the row id in the csr submat is relabeled, so
const auto len = vids->shape[0]; // we need to recover it using an index select.
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data); auto row = aten::IndexSelect(vids, coosubmat.row);
int64_t rstlen = 0; return CSR::EdgeArray{row, coosubmat.col, coosubmat.data};
for (int64_t i = 0; i < len; ++i) {
dgl_id_t vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rstlen += OutDegree(vid);
}
IdArray src = NewIdArray(rstlen);
IdArray dst = NewIdArray(rstlen);
IdArray eid = NewIdArray(rstlen);
dgl_id_t* src_ptr = static_cast<dgl_id_t*>(src->data);
dgl_id_t* dst_ptr = static_cast<dgl_id_t*>(dst->data);
dgl_id_t* eid_ptr = static_cast<dgl_id_t*>(eid->data);
for (int64_t i = 0; i < len; ++i) {
const dgl_id_t vid = vid_data[i];
const dgl_id_t off = indptr_data[vid];
const int64_t deg = OutDegree(vid);
if (deg == 0)
continue;
const auto *succ = indices_data + off;
const auto *eids = edge_ids_data + off;
for (int64_t j = 0; j < deg; ++j) {
*(src_ptr++) = vid;
*(dst_ptr++) = succ[j];
*(eid_ptr++) = eids[j];
}
}
return CSR::EdgeArray{src, dst, eid};
} }
DegreeArray CSR::OutDegrees(IdArray vids) const { DegreeArray CSR::OutDegrees(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto len = vids->shape[0]; return aten::CSRGetRowNNZ(adj_, vids);
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data);
DegreeArray rst = DegreeArray::Empty({len}, vids->dtype, vids->ctx);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) {
const auto vid = vid_data[i];
CHECK(HasVertex(vid)) << "Invalid vertex: " << vid;
rst_data[i] = OutDegree(vid);
}
return rst;
} }
bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const { bool CSR::HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src)) << "Invalid vertex id: " << src; CHECK(HasVertex(src)) << "Invalid vertex id: " << src;
CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst; CHECK(HasVertex(dst)) << "Invalid vertex id: " << dst;
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); return aten::CSRIsNonZero(adj_, src, dst);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data); }
for (dgl_id_t i = indptr_data[src]; i < indptr_data[src+1]; ++i) {
if (indices_data[i] == dst) { BoolArray CSR::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
return true; CHECK(IsValidIdArray(src_ids)) << "Invalid vertex id array.";
} CHECK(IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
} return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
return false;
} }
IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const { IdArray CSR::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
CHECK(radius == 1) << "invalid radius: " << radius; CHECK(radius == 1) << "invalid radius: " << radius;
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); return aten::CSRGetRowColumnIndices(adj_, vid);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
const int64_t len = indptr_data[vid + 1] - indptr_data[vid];
IdArray rst = NewIdArray(len);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
std::copy(indices_data + indptr_data[vid],
indices_data + indptr_data[vid + 1],
rst_data);
return rst;
} }
IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const { IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
// TODO(minjie): use more efficient binary search when the column indices
// are also sorted.
CHECK(HasVertex(src)) << "invalid vertex: " << src; CHECK(HasVertex(src)) << "invalid vertex: " << src;
CHECK(HasVertex(dst)) << "invalid vertex: " << dst; CHECK(HasVertex(dst)) << "invalid vertex: " << dst;
std::vector<dgl_id_t> ids; return aten::CSRGetData(adj_, src, dst);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
for (dgl_id_t i = indptr_data[src]; i < indptr_data[src+1]; ++i) {
if (indices_data[i] == dst) {
ids.push_back(eid_data[i]);
}
}
return VecToIdArray(ids);
} }
CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const { CSR::EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
// TODO(minjie): more efficient implementation for simple graph const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
CHECK(IsValidIdArray(src_ids)) << "Invalid src id array."; return CSR::EdgeArray{arrs[0], arrs[1], arrs[2]};
CHECK(IsValidIdArray(dst_ids)) << "Invalid dst id array.";
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))
<< "Invalid src and dst id array.";
const int src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;
const int dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
std::vector<dgl_id_t> src, dst, eid;
for (int64_t i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) {
const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
CHECK(HasVertex(src_id) && HasVertex(dst_id)) <<
"invalid edge: " << src_id << " -> " << dst_id;
for (dgl_id_t i = indptr_data[src_id]; i < indptr_data[src_id+1]; ++i) {
if (indices_data[i] == dst_id) {
src.push_back(src_id);
dst.push_back(dst_id);
eid.push_back(eid_data[i]);
}
}
}
return CSR::EdgeArray{VecToIdArray(src), VecToIdArray(dst), VecToIdArray(eid)};
} }
CSR::EdgeArray CSR::Edges(const std::string &order) const { CSR::EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst")) CHECK(order.empty() || order == std::string("srcdst"))
<< "COO only support Edges of order \"srcdst\"," << "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\"."; << " but got \"" << order << "\".";
const int64_t rstlen = NumEdges(); const auto& coo = aten::CSRToCOO(adj_, false);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); return CSR::EdgeArray{coo.row, coo.col, coo.data};
IdArray rst_src = NewIdArray(rstlen);
dgl_id_t* rst_src_data = static_cast<dgl_id_t*>(rst_src->data);
// If sorted, the returned edges are sorted by the source Id and dest Id.
for (dgl_id_t src = 0; src < NumVertices(); ++src) {
std::fill(rst_src_data + indptr_data[src],
rst_src_data + indptr_data[src + 1],
src);
}
return CSR::EdgeArray{rst_src, indices_, edge_ids_};
} }
Subgraph CSR::VertexSubgraph(IdArray vids) const { Subgraph CSR::VertexSubgraph(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
IdHashMap hashmap(vids); const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
const int64_t len = vids->shape[0]; CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
return Subgraph{subcsr, vids, submat.data};
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
std::vector<dgl_id_t> sub_indptr, sub_indices, sub_eids, induced_edges;
sub_indptr.resize(len + 1, 0);
const dgl_id_t kInvalidId = len + 1;
for (int64_t i = 0; i < len; ++i) {
// NOTE: newv == i
const dgl_id_t oldv = vid_data[i];
CHECK(HasVertex(oldv)) << "Invalid vertex: " << oldv;
for (dgl_id_t olde = indptr_data[oldv]; olde < indptr_data[oldv+1]; ++olde) {
const dgl_id_t oldu = indices_data[olde];
const dgl_id_t newu = hashmap.Map(oldu, kInvalidId);
if (newu != kInvalidId) {
++sub_indptr[i];
sub_indices.push_back(newu);
induced_edges.push_back(eid_data[olde]);
}
}
}
sub_eids.resize(sub_indices.size());
std::iota(sub_eids.begin(), sub_eids.end(), 0);
// cumsum sub_indptr
for (int64_t i = 0, cumsum = 0; i < len; ++i) {
const dgl_id_t temp = sub_indptr[i];
sub_indptr[i] = cumsum;
cumsum += temp;
}
sub_indptr[len] = sub_indices.size();
CSRPtr subcsr(new CSR(
VecToIdArray(sub_indptr), VecToIdArray(sub_indices), VecToIdArray(sub_eids)));
return Subgraph{subcsr, vids, VecToIdArray(induced_edges)};
} }
// complexity: time O(E + V), space O(1)
CSRPtr CSR::Transpose() const { CSRPtr CSR::Transpose() const {
const int64_t N = NumVertices(); const auto& trans = aten::CSRTranspose(adj_);
const int64_t M = NumEdges(); return CSRPtr(new CSR(trans.indptr, trans.indices, trans.data));
const dgl_id_t* Ap = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* Aj = static_cast<dgl_id_t*>(indices_->data);
const dgl_id_t* Ax = static_cast<dgl_id_t*>(edge_ids_->data);
IdArray ret_indptr = NewIdArray(N + 1);
IdArray ret_indices = NewIdArray(M);
IdArray ret_edge_ids = NewIdArray(M);
dgl_id_t* Bp = static_cast<dgl_id_t*>(ret_indptr->data);
dgl_id_t* Bi = static_cast<dgl_id_t*>(ret_indices->data);
dgl_id_t* Bx = static_cast<dgl_id_t*>(ret_edge_ids->data);
std::fill(Bp, Bp + N, 0);
for (int64_t j = 0; j < M; ++j) {
Bp[Aj[j]]++;
}
// cumsum
for (int64_t i = 0, cumsum = 0; i < N; ++i) {
const dgl_id_t temp = Bp[i];
Bp[i] = cumsum;
cumsum += temp;
}
Bp[N] = M;
for (int64_t i = 0; i < N; ++i) {
for (dgl_id_t j = Ap[i]; j < Ap[i+1]; ++j) {
const dgl_id_t dst = Aj[j];
Bi[Bp[dst]] = i;
Bx[Bp[dst]] = Ax[j];
Bp[dst]++;
}
}
// correct the indptr
for (int64_t i = 0, last = 0; i <= N; ++i) {
dgl_id_t temp = Bp[i];
Bp[i] = last;
last = temp;
}
return CSRPtr(new CSR(ret_indptr, ret_indices, ret_edge_ids));
} }
// complexity: time O(E + V), space O(1)
COOPtr CSR::ToCOO() const { COOPtr CSR::ToCOO() const {
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data); const auto& coo = aten::CSRToCOO(adj_, true);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data); return COOPtr(new COO(NumVertices(), coo.row, coo.col));
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
IdArray ret_src = NewIdArray(NumEdges());
IdArray ret_dst = NewIdArray(NumEdges());
dgl_id_t* ret_src_data = static_cast<dgl_id_t*>(ret_src->data);
dgl_id_t* ret_dst_data = static_cast<dgl_id_t*>(ret_dst->data);
// scatter by edge id
for (dgl_id_t src = 0; src < NumVertices(); ++src) {
for (dgl_id_t eid = indptr_data[src]; eid < indptr_data[src + 1]; ++eid) {
const dgl_id_t dst = indices_data[eid];
ret_src_data[eid_data[eid]] = src;
ret_dst_data[eid_data[eid]] = dst;
}
}
return COOPtr(new COO(NumVertices(), ret_src, ret_dst));
} }
CSR CSR::CopyTo(const DLContext& ctx) const { CSR CSR::CopyTo(const DLContext& ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
// TODO(minjie): change to use constructor later CSR ret(adj_.indptr.CopyTo(ctx),
CSR ret; adj_.indices.CopyTo(ctx),
ret.indptr_ = indptr_.CopyTo(ctx); adj_.data.CopyTo(ctx));
ret.indices_ = indices_.CopyTo(ctx);
ret.edge_ids_ = edge_ids_.CopyTo(ctx);
ret.is_multigraph_ = is_multigraph_; ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
...@@ -476,7 +219,7 @@ CSR CSR::CopyToSharedMem(const std::string &name) const { ...@@ -476,7 +219,7 @@ CSR CSR::CopyToSharedMem(const std::string &name) const {
CHECK(name == shared_mem_name_); CHECK(name == shared_mem_name_);
return *this; return *this;
} else { } else {
return CSR(indptr_, indices_, edge_ids_, name); return CSR(adj_.indptr, adj_.indices, adj_.data, name);
} }
} }
...@@ -484,192 +227,113 @@ CSR CSR::AsNumBits(uint8_t bits) const { ...@@ -484,192 +227,113 @@ CSR CSR::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
} else { } else {
// TODO(minjie): change to use constructor later CSR ret(aten::AsNumBits(adj_.indptr, bits),
CSR ret; aten::AsNumBits(adj_.indices, bits),
ret.indptr_ = dgl::AsNumBits(indptr_, bits); aten::AsNumBits(adj_.data, bits));
ret.indices_ = dgl::AsNumBits(indices_, bits);
ret.edge_ids_ = dgl::AsNumBits(edge_ids_, bits);
ret.is_multigraph_ = is_multigraph_; ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(indices_data + start, indices_data + end);
}
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(eid_data + start, eid_data + end);
}
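SuccVec and OutEdgeVec return iterator ranges straight into the CSR arrays: the out-neighbors and out-edge ids of vertex v are the contiguous slice [indptr[v], indptr[v+1]). A standalone sketch of the same lookup with plain vectors; the example graph is made up:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Toy graph with edges 0:(0->1), 1:(0->2), 2:(1->2), 3:(2->0).
  std::vector<int64_t> indptr  = {0, 2, 3, 4};
  std::vector<int64_t> indices = {1, 2, 2, 0};
  std::vector<int64_t> edge_id = {0, 1, 2, 3};

  const int64_t v = 0;
  // The out-neighbors of v occupy indices[indptr[v] .. indptr[v+1]).
  for (int64_t i = indptr[v]; i < indptr[v + 1]; ++i) {
    std::cout << "edge " << edge_id[i] << ": " << v << " -> " << indices[i] << "\n";
  }
  return 0;
}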
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
  : is_multigraph_(is_multigraph) {
  CHECK(IsValidIdArray(src));
  CHECK(IsValidIdArray(dst));
  CHECK_EQ(src->shape[0], dst->shape[0]);
  adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
    std::unordered_set<std::pair<dgl_id_t, dgl_id_t>, PairHash> hashmap;
    const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
    const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
    for (dgl_id_t eid = 0; eid < NumEdges(); ++eid) {
      const auto& p = std::make_pair(src_data[eid], dst_data[eid]);
      if (hashmap.count(p)) {
        return true;
      } else {
        hashmap.insert(p);
      }
    }
    return false;
  });
}

bool COO::IsMultigraph() const {
  // The lambda will be called the first time to initialize the is_multigraph flag.
  return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
    return aten::COOHasDuplicate(adj_);
  });
}
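The removed lambda detects parallel edges with a hash set of (src, dst) pairs; the new code assumes aten::COOHasDuplicate performs the equivalent check on the COO matrix. A minimal standalone version of that test:

#include <cstdint>
#include <set>
#include <utility>
#include <vector>

// A COO edge list describes a multigraph iff some (src, dst) pair repeats.
bool HasDuplicateEdge(const std::vector<int64_t>& src,
                      const std::vector<int64_t>& dst) {
  std::set<std::pair<int64_t, int64_t>> seen;
  for (size_t e = 0; e < src.size(); ++e) {
    // insert() reports failure when the pair is already present.
    if (!seen.insert({src[e], dst[e]}).second) return true;
  }
  return false;
}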
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
const auto src = aten::IndexSelect(adj_.row, eid);
const auto dst = aten::IndexSelect(adj_.col, eid);
return std::pair<dgl_id_t, dgl_id_t>(src, dst);
}
COO::EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
  dgl_id_t* eid_data = static_cast<dgl_id_t*>(eids->data);
  int64_t len = eids->shape[0];
  IdArray rst_src = NewIdArray(len);
  IdArray rst_dst = NewIdArray(len);
  dgl_id_t* rst_src_data = static_cast<dgl_id_t*>(rst_src->data);
  dgl_id_t* rst_dst_data = static_cast<dgl_id_t*>(rst_dst->data);
  for (int64_t i = 0; i < len; i++) {
    auto edge = COO::FindEdge(eid_data[i]);
    rst_src_data[i] = edge.first;
    rst_dst_data[i] = edge.second;
  }
  return COO::EdgeArray{rst_src, rst_dst, eids};
}

COO::EdgeArray COO::FindEdges(IdArray eids) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
  return EdgeArray{aten::IndexSelect(adj_.row, eids),
                   aten::IndexSelect(adj_.col, eids),
                   eids};
}
COO::EdgeArray COO::Edges(const std::string &order) const {
  const int64_t rstlen = NumEdges();
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only supports Edges of order \"eid\", but got \""
    << order << "\".";
  IdArray rst_eid = NewIdArray(rstlen);
  dgl_id_t* rst_eid_data = static_cast<dgl_id_t*>(rst_eid->data);
  std::iota(rst_eid_data, rst_eid_data + rstlen, 0);
  return EdgeArray{src_, dst_, rst_eid};
}

COO::EdgeArray COO::Edges(const std::string &order) const {
  CHECK(order.empty() || order == std::string("eid"))
    << "COO only supports Edges of order \"eid\", but got \""
    << order << "\".";
  IdArray rst_eid = aten::Range(0, NumEdges(), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, rst_eid};
}
Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
  const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
  const dgl_id_t* eids_data = static_cast<dgl_id_t*>(eids->data);
  IdArray new_src = NewIdArray(eids->shape[0]);
  IdArray new_dst = NewIdArray(eids->shape[0]);
  dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data);
  dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data);
  if (!preserve_nodes) {
    dgl_id_t newid = 0;
    std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
    for (int64_t i = 0; i < eids->shape[0]; ++i) {
      const dgl_id_t eid = eids_data[i];
      const dgl_id_t src = src_data[eid];
      const dgl_id_t dst = dst_data[eid];
      if (!oldv2newv.count(src)) {
        oldv2newv[src] = newid++;
      }
      if (!oldv2newv.count(dst)) {
        oldv2newv[dst] = newid++;
      }
      *(new_src_data++) = oldv2newv[src];
      *(new_dst_data++) = oldv2newv[dst];
    }
    // induced nodes
    IdArray induced_nodes = NewIdArray(newid);
    dgl_id_t* induced_nodes_data = static_cast<dgl_id_t*>(induced_nodes->data);
    for (const auto& kv : oldv2newv) {
      induced_nodes_data[kv.second] = kv.first;
    }
    COOPtr subcoo(new COO(newid, new_src, new_dst));
    return Subgraph{subcoo, induced_nodes, eids};
  } else {
    for (int64_t i = 0; i < eids->shape[0]; ++i) {
      const dgl_id_t eid = eids_data[i];
      const dgl_id_t src = src_data[eid];
      const dgl_id_t dst = dst_data[eid];
      *(new_src_data++) = src;
      *(new_dst_data++) = dst;
    }
    IdArray induced_nodes = NewIdArray(NumVertices());
    dgl_id_t* induced_nodes_data = static_cast<dgl_id_t*>(induced_nodes->data);
    for (int64_t i = 0; i < NumVertices(); ++i)
      *(induced_nodes_data++) = i;
    COOPtr subcoo(new COO(NumVertices(), new_src, new_dst));
    return Subgraph{subcoo, induced_nodes, eids};
  }
}

Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
  CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
  if (!preserve_nodes) {
    IdArray new_src = aten::IndexSelect(adj_.row, eids);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids);
    IdArray induced_nodes = aten::Relabel_({new_src, new_dst});
    const auto new_nnodes = induced_nodes->shape[0];
    COOPtr subcoo(new COO(new_nnodes, new_src, new_dst));
    return Subgraph{subcoo, induced_nodes, eids};
  } else {
    IdArray new_src = aten::IndexSelect(adj_.row, eids);
    IdArray new_dst = aten::IndexSelect(adj_.col, eids);
    IdArray induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
    COOPtr subcoo(new COO(NumVertices(), new_src, new_dst));
    return Subgraph{subcoo, induced_nodes, eids};
  }
}
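In the !preserve_nodes branch, aten::Relabel_ is expected to compact the endpoint ids of the selected edges to 0..k-1 in place and return the induced original ids, which is what the removed oldv2newv loop computed. A sketch of that relabeling step; the helper below is illustrative, not the aten implementation:

#include <cstdint>
#include <unordered_map>
#include <vector>

// Map the node ids in new_src/new_dst to consecutive ids in first-seen order,
// rewriting both vectors in place; induced[new_id] gives the original id.
std::vector<int64_t> RelabelInPlace(std::vector<int64_t>* new_src,
                                    std::vector<int64_t>* new_dst) {
  std::unordered_map<int64_t, int64_t> old2new;
  std::vector<int64_t> induced;
  auto remap = [&](int64_t old_id) {
    auto it = old2new.find(old_id);
    if (it != old2new.end()) return it->second;
    const int64_t new_id = static_cast<int64_t>(induced.size());
    old2new.emplace(old_id, new_id);
    induced.push_back(old_id);
    return new_id;
  };
  for (int64_t& v : *new_src) v = remap(v);
  for (int64_t& v : *new_dst) v = remap(v);
  return induced;
}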
// complexity: time O(E + V), space O(1)
CSRPtr COO::ToCSR() const {
  const int64_t N = num_vertices_;
  const int64_t M = src_->shape[0];
  const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
  const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
  IdArray indptr = NewIdArray(N + 1);
  IdArray indices = NewIdArray(M);
  IdArray edge_ids = NewIdArray(M);
  dgl_id_t* Bp = static_cast<dgl_id_t*>(indptr->data);
  dgl_id_t* Bi = static_cast<dgl_id_t*>(indices->data);
  dgl_id_t* Bx = static_cast<dgl_id_t*>(edge_ids->data);
  std::fill(Bp, Bp + N, 0);
  for (int64_t i = 0; i < M; ++i) {
    Bp[src_data[i]]++;
  }
  // cumsum
  for (int64_t i = 0, cumsum = 0; i < N; ++i) {
    const dgl_id_t temp = Bp[i];
    Bp[i] = cumsum;
    cumsum += temp;
  }
  Bp[N] = M;
  for (int64_t i = 0; i < M; ++i) {
    const dgl_id_t src = src_data[i];
    const dgl_id_t dst = dst_data[i];
    Bi[Bp[src]] = dst;
    Bx[Bp[src]] = i;
    Bp[src]++;
  }
  // correct the indptr
  for (int64_t i = 0, last = 0; i <= N; ++i) {
    dgl_id_t temp = Bp[i];
    Bp[i] = last;
    last = temp;
  }
  return CSRPtr(new CSR(indptr, indices, edge_ids));
}

CSRPtr COO::ToCSR() const {
  const auto& csr = aten::COOToCSR(adj_);
  return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data));
}
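The removed body is a counting sort on the source ids, which is why the old comment advertises O(E + V) time; the refactor moves the same algorithm behind aten::COOToCSR. A self-contained version over std::vector, with illustrative names only:

#include <cstdint>
#include <vector>

// Convert a COO edge list on n vertices into CSR; data[] keeps the original
// edge id of every CSR entry so per-edge features can still be located.
void CooToCsr(int64_t n,
              const std::vector<int64_t>& src,
              const std::vector<int64_t>& dst,
              std::vector<int64_t>* indptr,
              std::vector<int64_t>* indices,
              std::vector<int64_t>* data) {
  const int64_t m = static_cast<int64_t>(src.size());
  indptr->assign(n + 1, 0);
  indices->assign(m, 0);
  data->assign(m, 0);
  for (int64_t e = 0; e < m; ++e)      // count the out-degree of each vertex
    ++(*indptr)[src[e] + 1];
  for (int64_t v = 0; v < n; ++v)      // prefix sum -> row offsets
    (*indptr)[v + 1] += (*indptr)[v];
  std::vector<int64_t> next(indptr->begin(), indptr->end() - 1);
  for (int64_t e = 0; e < m; ++e) {    // scatter each edge into its row
    const int64_t pos = next[src[e]]++;
    (*indices)[pos] = dst[e];
    (*data)[pos] = e;
  }
}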
COO COO::CopyTo(const DLContext& ctx) const {
  if (Context() == ctx) {
    return *this;
  } else {
    COO ret(NumVertices(),
            adj_.row.CopyTo(ctx),
            adj_.col.CopyTo(ctx));
    ret.is_multigraph_ = is_multigraph_;
    return ret;
  }
@@ -677,17 +341,16 @@ COO COO::CopyTo(const DLContext& ctx) const {
COO COO::CopyToSharedMem(const std::string &name) const {
  LOG(FATAL) << "COO doesn't support shared memory yet";
  return COO();
}
COO COO::AsNumBits(uint8_t bits) const {
  if (NumBits() == bits) {
    return *this;
  } else {
    COO ret(NumVertices(),
            aten::AsNumBits(adj_.row, bits),
            aten::AsNumBits(adj_.col, bits));
    ret.is_multigraph_ = is_multigraph_;
    return ret;
  }
@@ -699,6 +362,55 @@ COO COO::AsNumBits(uint8_t bits) const {
//
//////////////////////////////////////////////////////////
BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid id array input";
return aten::LT(vids, NumVertices());
}
CSRPtr ImmutableGraph::GetInCSR() const {
if (!in_csr_) {
if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
}
}
return in_csr_;
}
/*! \brief Return the out-CSR. If it does not exist, construct it by transposing the in-CSR or converting the COO. */
CSRPtr ImmutableGraph::GetOutCSR() const {
if (!out_csr_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
}
}
return out_csr_;
}
/*! \brief Return the COO. If it does not exist, convert it from one of the CSRs. */
COOPtr ImmutableGraph::GetCOO() const {
if (!coo_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->coo_ = in_csr_->ToCOO()->Transpose();
} else {
CHECK(out_csr_) << "Both CSR are missing.";
const_cast<ImmutableGraph*>(this)->coo_ = out_csr_->ToCOO();
}
}
return coo_;
}
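GetInCSR, GetOutCSR and GetCOO build whichever representation is missing from whichever one already exists (an in-CSR is simply the out-CSR of the reversed graph) and cache it in a mutable member on first use. A toy sketch of that lazy-conversion pattern; these are illustrative types, not the DGL classes:

#include <memory>
#include <utility>
#include <vector>

struct EdgeList { std::vector<int> src, dst; };

class LazyGraph {
 public:
  explicit LazyGraph(EdgeList out_edges) : out_(std::move(out_edges)) {}

  // The "in" view is the reverse of the "out" view; build it on first request
  // and cache it, much like GetInCSR()/GetOutCSR() do with transposes above.
  const EdgeList& GetInEdges() const {
    if (!in_) in_ = std::make_unique<EdgeList>(EdgeList{out_.dst, out_.src});
    return *in_;
  }

 private:
  EdgeList out_;
  mutable std::unique_ptr<EdgeList> in_;  // lazily materialized, like in_csr_ above
};

Caching the converted form avoids recomputing it on every query, at the cost of holding both representations in memory, which is exactly the trade-off the shared-memory warning above points out.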
ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
  if (order.empty()) {
    // arbitrary order
@@ -783,9 +495,9 @@ ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir,
                                            const std::string &name) const {
  CSRPtr new_incsr, new_outcsr;
  std::string shared_mem_name = GetSharedMemName(name, edge_dir);
  if (edge_dir == std::string("in"))
    new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name)));
  else if (edge_dir == std::string("out"))
    new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name)));
  return ImmutableGraph(new_incsr, new_outcsr, name);
}
@@ -24,24 +24,24 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
  CHECK_GE(layer1_start, layer0_size);
  if (fmt == std::string("csr")) {
    dgl_id_t first_vid = layer1_start - layer0_size;
    auto csr = aten::CSRSliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
    if (remap) {
      dgl_id_t *eid_data = static_cast<dgl_id_t*>(csr.data->data);
      const dgl_id_t first_eid = eid_data[0];
      IdArray new_indices = aten::Sub(csr.indices, first_vid);
      IdArray new_data = aten::Sub(csr.data, first_eid);
      return {csr.indptr, new_indices, new_data};
    } else {
      return {csr.indptr, csr.indices, csr.data};
    }
  } else if (fmt == std::string("coo")) {
    auto csr = graph.GetInCSR()->ToCSRMatrix();
    const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data);
    const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data);
    const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr.data->data);
    int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
    IdArray idx = aten::NewIdArray(2 * nnz);
    IdArray eid = aten::NewIdArray(nnz);
    int64_t *idx_data = static_cast<int64_t*>(idx->data);
    dgl_id_t *eid_data = static_cast<dgl_id_t*>(eid->data);
    size_t num_edges = 0;
@@ -248,10 +248,10 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
                            int64_t num_edges, int num_hops, bool is_multigraph) {
  NodeFlow nf;
  uint64_t num_vertices = sub_vers->size();
  nf.node_mapping = aten::NewIdArray(num_vertices);
  nf.edge_mapping = aten::NewIdArray(num_edges);
  nf.layer_offsets = aten::NewIdArray(num_hops + 1);
  nf.flow_offsets = aten::NewIdArray(num_hops);
  dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data);
  dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data);
@@ -379,6 +379,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
                        int num_hops,
                        size_t num_neighbor,
                        const bool add_self_loop) {
  CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet";
  unsigned int time_seed = randseed();
  const size_t num_seeds = seeds.size();
  auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
@@ -702,8 +703,9 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
  CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
  NodeFlow nf;
  auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
                                aten::VecToIdArray(sub_indices),
                                aten::VecToIdArray(sub_edge_ids)));
  if (neighbor_type == std::string("in")) {
    nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
@@ -711,10 +713,10 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
    nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
  }
  nf.node_mapping = aten::VecToIdArray(node_mapping);
  nf.edge_mapping = aten::VecToIdArray(edge_mapping);
  nf.layer_offsets = aten::VecToIdArray(layer_offsets);
  nf.flow_offsets = aten::VecToIdArray(flow_offsets);
  return nf;
}