Unverified Commit b0d9e7aa authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Separating graph and sparse matrix operations (#699)

* WIP: array refactoring

* WIP: implementation

* wip

* most csr part

* WIP: on coo

* WIP: coo

* finish refactoring immutable graph

* compiled

* fix undefined ndarray copy bug; add COOToCSR when coo has no data array

* fix bug in COOToCSR

* fix bug in CSR constructor

* fix bug in in_edges(vid)

* fix OutEdges bug

* pass test_graph

* pass test_graph

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix bug in CSR constructor

* fix stupid bug

* pass gpu test

* remove debug printout

* fix lint

* rm bipartite graph

* fix lint

* address comments

* fix bug in Clone

* cpp utests
parent f79188da
......@@ -87,6 +87,8 @@ add_definitions(-DENABLE_PARTIAL_FRONTIER=0) # disable minigun partial frontier
# Source file lists
file(GLOB DGL_SRC
src/*.cc
src/array/*.cc
src/array/cpu/*.cc
src/kernel/*.cc
src/kernel/cpu/*.cc
src/runtime/*.cc
......
......@@ -231,6 +231,8 @@ macro(dgl_config_cuda out_variable)
add_definitions(-DDGL_USE_CUDA)
file(GLOB_RECURSE DGL_CUDA_SRC
src/array/cuda/*.cc
src/array/cuda/*.cu
src/kernel/cuda/*.cc
src/kernel/cuda/*.cu
src/runtime/cuda/*.cc
......
......@@ -10,7 +10,9 @@
#define DGL_ARRAY_H_
#include <dgl/runtime/ndarray.h>
#include <algorithm>
#include <vector>
#include <utility>
namespace dgl {
......@@ -23,22 +25,59 @@ typedef dgl::runtime::NDArray IntArray;
typedef dgl::runtime::NDArray FloatArray;
typedef dgl::runtime::NDArray TypeArray;
/*! \brief Create a new id array with given length (on CPU) */
IdArray NewIdArray(int64_t length);
namespace aten {
//////////////////////////////////////////////////////////////////////
// ID array
//////////////////////////////////////////////////////////////////////
/*!
* \brief Create a new id array with given length
* \param length The array length
* \param ctx The array context
* \param nbits The number of integer bits
* \return id array
*/
IdArray NewIdArray(int64_t length,
DLContext ctx = DLContext{kDLCPU, 0},
uint8_t nbits = 64);
/*!
* \brief Create a new boolean array with given length (on CPU)
* \note the elements are 64-bit.
* \brief Create a new id array using the given vector data
* \param vec The vector data
* \param nbits The integer bits of the returned array
* \param ctx The array context
* \return the id array
*/
BoolArray NewBoolArray(int64_t length);
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
DLContext ctx = DLContext{kDLCPU, 0});
/*! \brief Create a new id array with the given vector data (on CPU) */
IdArray VecToIdArray(const std::vector<dgl_id_t>& vec);
/*!
* \brief Return an array representing a 1D range.
* \param low Lower bound (inclusive).
* \param high Higher bound (exclusive).
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return range array
*/
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
/*! \brief Create a copy of the given array */
/*!
* \brief Return an array full of the given value
* \param val The value to fill.
* \param length Number of elements.
* \param nbits result array's bits (32 or 64)
* \param ctx Device context
* \return the result array
*/
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
/*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr);
/*! \brief Convert the idarray to the given bit width (on CPU) */
/*! \brief Convert the idarray to the given bit width */
IdArray AsNumBits(IdArray arr, uint8_t bits);
/*! \brief Arithmetic functions */
......@@ -57,30 +96,186 @@ IdArray Sub(dgl_id_t lhs, IdArray rhs);
IdArray Mul(dgl_id_t lhs, IdArray rhs);
IdArray Div(dgl_id_t lhs, IdArray rhs);
BoolArray LT(IdArray lhs, dgl_id_t rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2);
/*! \brief Return the data under the index. In numpy notation, A[I] */
int64_t IndexSelect(IdArray array, int64_t index);
IdArray IndexSelect(IdArray array, IdArray index);
/*!
* \brief Relabel the given ids to consecutive ids.
*
* Relabeling is done inplace. The mapping is created from the union
* of the give arrays.
*
* \param arrays The id arrays to relabel.
* \return mapping array M from new id to old id.
*/
IdArray Relabel_(const std::vector<IdArray>& arrays);
//////////////////////////////////////////////////////////////////////
// Sparse matrix
//////////////////////////////////////////////////////////////////////
/*! \brief Plain CSR matrix */
/*!
* \brief Plain CSR matrix
*
* The column indices are 0-based and are not necessarily sorted.
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/
struct CSRMatrix {
IdArray indptr, indices, data;
/*! \brief the dense shape of the matrix */
int64_t num_rows, num_cols;
/*! \brief CSR index arrays */
runtime::NDArray indptr, indices;
/*! \brief data array, could be empty. */
runtime::NDArray data;
};
/*! \brief Plain COO structure */
/*!
* \brief Plain COO structure
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*
* We call a COO matrix is *coalesced* if its row index is sorted.
*/
struct COOMatrix {
IdArray row, col, data;
/*! \brief the dense shape of the matrix */
int64_t num_rows, num_cols;
/*! \brief COO index arrays */
runtime::NDArray row, col;
/*! \brief data array, could be empty. */
runtime::NDArray data;
};
/*! \brief Slice rows of the given matrix and return. */
CSRMatrix SliceRows(const CSRMatrix& csr, int64_t start, int64_t end);
///////////////////////// CSR routines //////////////////////////
/*! \brief Return true if the value (row, col) is non-zero */
bool CSRIsNonZero(CSRMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of CSRIsNonZero.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray CSRIsNonZero(CSRMatrix, runtime::NDArray row, runtime::NDArray col);
/*! \brief Return the nnz of the given row */
int64_t CSRGetRowNNZ(CSRMatrix , int64_t row);
runtime::NDArray CSRGetRowNNZ(CSRMatrix , runtime::NDArray row);
/*! \brief Return the column index array of the given row */
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix , int64_t row);
/*! \brief Return the data array of the given row */
runtime::NDArray CSRGetRowData(CSRMatrix , int64_t row);
/*! \brief Convert COO matrix to CSR matrix. */
CSRMatrix ToCSR(const COOMatrix);
/* \brief Get data. The return type is an ndarray due to possible duplicate entries. */
runtime::NDArray CSRGetData(CSRMatrix , int64_t row, int64_t col);
/*!
* \brief Batched implementation of CSRGetData.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
runtime::NDArray CSRGetData(CSRMatrix, runtime::NDArray rows, runtime::NDArray cols);
/*!
* \brief Get the data and the row,col indices for each returned entries.
* \note This operator allows broadcasting (i.e, either row or col can be of length 1).
*/
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix , runtime::NDArray rows, runtime::NDArray cols);
/*! \brief Return a transposed CSR matrix */
CSRMatrix CSRTranspose(CSRMatrix csr);
/*!
* \brief Convert CSR matrix to COO matrix.
* \param csr Input csr matrix
* \param data_as_order If true, the data array in the input csr matrix contains the order
* by which the resulting COO tuples are stored. In this case, the
* data array of the resulting COO matrix will be empty because it
* is essentially a consecutive range.
* \return a coo matrix
*/
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order);
/*!
* \brief Slice rows of the given matrix and return.
* \param csr CSR matrix
* \param start Start row id (inclusive)
* \param end End row id (exclusive)
*
* Examples:
* num_rows = 4
* num_cols = 4
* indptr = [0, 2, 3, 3, 5]
* indices = [1, 0, 2, 3, 1]
*
* After CSRSliceRows(csr, 1, 3)
*
* num_rows = 2
* num_cols = 4
* indptr = [0, 1, 1]
* indices = [2]
*/
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
/*!
* \brief Get the submatrix specified by the row and col ids.
*
* In numpy notation, given matrix M, row index array I, col index array J
* This function returns the submatrix M[I, J].
*
* \param csr The input csr matrix
* \param rows The row index to select
* \param cols The col index to select
* \return submatrix
*/
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
/*! \return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr);
///////////////////////// COO routines //////////////////////////
/*! \return True if the matrix has duplicate entries */
bool COOHasDuplicate(COOMatrix coo);
/*!
* \brief Convert COO matrix to CSR matrix.
*
* If the input COO matrix does not have data array, the data array of
* the result CSR matrix stores a shuffle index for how the entries
* will be reordered in CSR. The i^th entry in the result CSR corresponds
* to the CSR.data[i] th entry in the input COO.
*/
CSRMatrix COOToCSR(COOMatrix coo);
/*! \brief Convert COO matrix to CSR matrix. */
COOMatrix ToCOO(const CSRMatrix);
// inline implementations
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
} else if (nbits == 64) {
std::copy(vec.begin(), vec.end(), static_cast<int64_t*>(ret->data));
} else {
LOG(FATAL) << "Only int32 or int64 is supported.";
}
return ret.CopyTo(ctx);
}
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_H_
......@@ -128,49 +128,13 @@ class GraphInterface {
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
virtual BoolArray HasVertices(IdArray vids) const {
const auto len = vids->shape[0];
BoolArray rst = NewBoolArray(len);
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const uint64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts)? 1 : 0;
}
return rst;
}
virtual BoolArray HasVertices(IdArray vids) const = 0;
/*! \return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = BoolArray::Empty({rstlen}, src_ids->dtype, src_ids->ctx);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0;
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0;
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0;
}
}
return rst;
}
virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;
/*!
* \brief Find the predecessors of a vertex.
......
......@@ -69,11 +69,11 @@ class CSR : public GraphInterface {
}
DLContext Context() const override {
return indptr_->ctx;
return adj_.indptr->ctx;
}
uint8_t NumBits() const override {
return indices_->dtype.bits;
return adj_.indices->dtype.bits;
}
bool IsMultigraph() const override;
......@@ -83,15 +83,22 @@ class CSR : public GraphInterface {
}
uint64_t NumVertices() const override {
return indptr_->shape[0] - 1;
return adj_.indptr->shape[0] - 1;
}
uint64_t NumEdges() const override {
return indices_->shape[0];
return adj_.indices->shape[0];
}
BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return {};
}
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override;
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override;
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "CSR graph does not support efficient predecessor query."
<< " Please use successors on the reverse CSR graph.";
......@@ -147,8 +154,7 @@ class CSR : public GraphInterface {
}
uint64_t OutDegree(dgl_id_t vid) const override {
const int64_t* indptr_data = static_cast<int64_t*>(indptr_->data);
return indptr_data[vid + 1] - indptr_data[vid];
return aten::CSRGetRowNNZ(adj_, vid);
}
DegreeArray OutDegrees(IdArray vids) const override;
......@@ -165,21 +171,9 @@ class CSR : public GraphInterface {
return Transpose();
}
DGLIdIters SuccVec(dgl_id_t vid) const override {
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(indices_data + start, indices_data + end);
}
DGLIdIters SuccVec(dgl_id_t vid) const override;
DGLIdIters OutEdgeVec(dgl_id_t vid) const override {
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(edge_ids_->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(eid_data + start, eid_data + end);
}
DGLIdIters OutEdgeVec(dgl_id_t vid) const override;
DGLIdIters PredVec(dgl_id_t vid) const override {
LOG(FATAL) << "CSR graph does not support efficient PredVec."
......@@ -201,7 +195,7 @@ class CSR : public GraphInterface {
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {indptr_, indices_, edge_ids_};
return {adj_.indptr, adj_.indices, adj_.data};
}
/*! \brief Indicate whether this uses shared memory. */
......@@ -220,8 +214,8 @@ class CSR : public GraphInterface {
* \note The csr matrix shares the storage with this graph.
* The data field of the CSR matrix stores the edge ids.
*/
CSRMatrix ToCSRMatrix() const {
return CSRMatrix{indptr_, indices_, edge_ids_};
aten::CSRMatrix ToCSRMatrix() const {
return adj_;
}
/*!
......@@ -247,26 +241,19 @@ class CSR : public GraphInterface {
// member getters
IdArray indptr() const { return indptr_; }
IdArray indptr() const { return adj_.indptr; }
IdArray indices() const { return indices_; }
IdArray indices() const { return adj_.indices; }
IdArray edge_ids() const { return edge_ids_; }
IdArray edge_ids() const { return adj_.data; }
private:
/*! \brief private default constructor */
CSR() {}
// The CSR arrays.
// - The index is 0-based.
// - The out edges of vertex v is stored from `indices_[indptr_[v]]` to
// `indices_[indptr_[v+1]]` (exclusive).
// - The indices are *not* necessarily sorted.
// TODO(minjie): in the future, we should separate CSR and graph. A general CSR
// is not necessarily a graph, but graph operations could be implemented by
// CSR matrix operations. CSR matrix operations would be backed by different
// devices (CPU, CUDA, ...), while graph interface will not be aware of that.
IdArray indptr_, indices_, edge_ids_;
// The internal CSR adjacency matrix.
// The data field stores edge ids.
aten::CSRMatrix adj_;
// whether the graph is a multi-graph
LazyObject<bool> is_multigraph_;
......@@ -301,11 +288,11 @@ class COO : public GraphInterface {
}
DLContext Context() const override {
return src_->ctx;
return adj_.row->ctx;
}
uint8_t NumBits() const override {
return src_->dtype.bits;
return adj_.row->dtype.bits;
}
bool IsMultigraph() const override;
......@@ -315,23 +302,34 @@ class COO : public GraphInterface {
}
uint64_t NumVertices() const override {
return num_vertices_;
return adj_.num_rows;
}
uint64_t NumEdges() const override {
return src_->shape[0];
return adj_.row->shape[0];
}
bool HasVertex(dgl_id_t vid) const override {
return vid < NumVertices();
}
BoolArray HasVertices(IdArray vids) const override {
LOG(FATAL) << "Not enabled for COO graph";
return {};
}
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead.";
return false;
}
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const override {
LOG(FATAL) << "COO graph does not support efficient HasEdgeBetween."
<< " Please use CSR graph or AdjList graph instead.";
return {};
}
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const override {
LOG(FATAL) << "COO graph does not support efficient Predecessors."
<< " Please use CSR graph or AdjList graph instead.";
......@@ -356,12 +354,7 @@ class COO : public GraphInterface {
return {};
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override {
CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_->data);
return std::make_pair(src_data[eid], dst_data[eid]);
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const override;
EdgeArray FindEdges(IdArray eids) const override;
......@@ -460,15 +453,15 @@ class COO : public GraphInterface {
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) {
return {HStack(dst_, src_)};
return {aten::HStack(adj_.col, adj_.row)};
} else {
return {HStack(src_, dst_)};
return {aten::HStack(adj_.row, adj_.col)};
}
}
/*! \brief Return the transpose of this COO */
COOPtr Transpose() const {
return COOPtr(new COO(num_vertices_, dst_, src_));
return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
}
/*! \brief Convert this COO to CSR */
......@@ -479,8 +472,8 @@ class COO : public GraphInterface {
* \note The coo matrix shares the storage with this graph.
* The data field of the coo matrix is none.
*/
COOMatrix ToCOOMatrix() const {
return COOMatrix{src_, dst_, {}};
aten::COOMatrix ToCOOMatrix() const {
return adj_;
}
/*!
......@@ -511,18 +504,18 @@ class COO : public GraphInterface {
// member getters
IdArray src() const { return src_; }
IdArray src() const { return adj_.row; }
IdArray dst() const { return dst_; }
IdArray dst() const { return adj_.col; }
private:
/*! \brief private default constructor */
COO() {}
/*! \brief number of vertices */
int64_t num_vertices_;
/*! \brief coordinate arrays */
IdArray src_, dst_;
// The internal COO adjacency matrix.
// The data field is empty
aten::COOMatrix adj_;
/*! \brief whether the graph is a multi-graph */
LazyObject<bool> is_multigraph_;
};
......@@ -635,6 +628,8 @@ class ImmutableGraph: public GraphInterface {
return vid < NumVertices();
}
BoolArray HasVertices(IdArray vids) const override;
/*! \return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const override {
if (in_csr_) {
......@@ -644,6 +639,14 @@ class ImmutableGraph: public GraphInterface {
}
}
BoolArray HasEdgesBetween(IdArray src, IdArray dst) const override {
if (in_csr_) {
return in_csr_->HasEdgesBetween(dst, src);
} else {
return GetOutCSR()->HasEdgesBetween(src, dst);
}
}
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
......@@ -910,49 +913,13 @@ class ImmutableGraph: public GraphInterface {
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
/* !\brief Return in csr. If not exist, transpose the other one.*/
CSRPtr GetInCSR() const {
if (!in_csr_) {
if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
}
}
return in_csr_;
}
CSRPtr GetInCSR() const;
/* !\brief Return out csr. If not exist, transpose the other one.*/
CSRPtr GetOutCSR() const {
if (!out_csr_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
}
}
return out_csr_;
}
CSRPtr GetOutCSR() const;
/* !\brief Return coo. If not exist, create from csr.*/
COOPtr GetCOO() const {
if (!coo_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->coo_ = in_csr_->ToCOO()->Transpose();
} else {
CHECK(out_csr_) << "Both CSR are missing.";
const_cast<ImmutableGraph*>(this)->coo_ = out_csr_->ToCOO();
}
}
return coo_;
}
COOPtr GetCOO() const;
/*!
* \brief Convert the given graph to an immutable graph.
......@@ -1107,12 +1074,16 @@ template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
bool is_multigraph): is_multigraph_(is_multigraph) {
indptr_ = NewIdArray(num_vertices + 1);
indices_ = NewIdArray(num_edges);
edge_ids_ = NewIdArray(num_edges);
dgl_id_t* indptr_data = static_cast<dgl_id_t*>(indptr_->data);
dgl_id_t* indices_data = static_cast<dgl_id_t*>(indices_->data);
dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(edge_ids_->data);
// TODO(minjie): this should be changed to a device-agnostic implementation
// in the future
adj_.num_rows = num_vertices;
adj_.num_cols = num_vertices;
adj_.indptr = aten::NewIdArray(num_vertices + 1);
adj_.indices = aten::NewIdArray(num_edges);
adj_.data = aten::NewIdArray(num_edges);
dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
dgl_id_t* edge_ids_data = static_cast<dgl_id_t*>(adj_.data->data);
for (int64_t i = 0; i < num_vertices + 1; ++i)
*(indptr_data++) = *(indptr_begin++);
for (int64_t i = 0; i < num_edges; ++i) {
......
......@@ -292,16 +292,18 @@ struct NDArray::Container {
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
if (data_)
data_->IncRef();
}
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
if (data_)
data_->IncRef();
}
inline void NDArray::reset() {
if (data_ != nullptr) {
if (data_) {
data_->DecRef();
data_ = nullptr;
}
......
......@@ -85,7 +85,7 @@ class ObjectBase(object):
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
handle = __init_by_constructor__(fconstructor, args) # pylint: disable=not-callable
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
......
......@@ -151,6 +151,7 @@ class GraphIndex(object):
readonly_state : bool
New readonly state of current graph index.
"""
# TODO(minjie): very ugly code, should fix this
n_nodes, multigraph, _, src, dst = self.__getstate__()
self.clear_cache()
state = (n_nodes, multigraph, readonly_state, src, dst)
......
/*!
* Copyright (c) 2019 by Contributors
* \file array.cc
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
namespace dgl {
// TODO(minjie): currently these operators are only on CPU.
IdArray NewIdArray(int64_t length) {
return IdArray::Empty({length}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
BoolArray NewBoolArray(int64_t length) {
return BoolArray::Empty({length}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
}
IdArray VecToIdArray(const std::vector<dgl_id_t>& vec) {
IdArray ret = NewIdArray(vec.size());
std::copy(vec.begin(), vec.end(), static_cast<dgl_id_t*>(ret->data));
return ret;
}
IdArray Clone(IdArray arr) {
IdArray ret = NewIdArray(arr->shape[0]);
ret.CopyFrom(arr);
return ret;
}
IdArray AsNumBits(IdArray arr, uint8_t bits) {
if (arr->dtype.bits == bits) {
return arr;
} else {
const int64_t len = arr->shape[0];
IdArray ret = IdArray::Empty({len},
DLDataType{kDLInt, bits, 1}, DLContext{kDLCPU, 0});
if (arr->dtype.bits == 32 && bits == 64) {
const int32_t* arr_data = static_cast<int32_t*>(arr->data);
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
} else if (arr->dtype.bits == 64 && bits == 32) {
const int64_t* arr_data = static_cast<int64_t*>(arr->data);
int32_t* ret_data = static_cast<int32_t*>(ret->data);
for (int64_t i = 0; i < len; ++i) {
ret_data[i] = arr_data[i];
}
} else {
LOG(FATAL) << "Invalid type conversion.";
}
return ret;
}
}
IdArray Add(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] + rhs_data[i];
}
return ret;
}
IdArray Sub(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] - rhs_data[i];
}
return ret;
}
IdArray Mul(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] * rhs_data[i];
}
return ret;
}
IdArray Div(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] / rhs_data[i];
}
return ret;
}
IdArray Add(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] + rhs;
}
return ret;
}
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] - rhs;
}
return ret;
}
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] * rhs;
}
return ret;
}
IdArray Div(IdArray lhs, dgl_id_t rhs) {
IdArray ret = NewIdArray(lhs->shape[0]);
const dgl_id_t* lhs_data = static_cast<dgl_id_t*>(lhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < lhs->shape[0]; ++i) {
ret_data[i] = lhs_data[i] / rhs;
}
return ret;
}
IdArray Add(dgl_id_t lhs, IdArray rhs) {
return Add(rhs, lhs);
}
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0]);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = lhs - rhs_data[i];
}
return ret;
}
IdArray Mul(dgl_id_t lhs, IdArray rhs) {
return Mul(rhs, lhs);
}
IdArray Div(dgl_id_t lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0]);
const dgl_id_t* rhs_data = static_cast<dgl_id_t*>(rhs->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < rhs->shape[0]; ++i) {
ret_data[i] = lhs / rhs_data[i];
}
return ret;
}
IdArray HStack(IdArray arr1, IdArray arr2) {
CHECK_EQ(arr1->shape[0], arr2->shape[0]);
const int64_t L = arr1->shape[0];
IdArray ret = NewIdArray(2 * L);
const dgl_id_t* arr1_data = static_cast<dgl_id_t*>(arr1->data);
const dgl_id_t* arr2_data = static_cast<dgl_id_t*>(arr2->data);
dgl_id_t* ret_data = static_cast<dgl_id_t*>(ret->data);
for (int64_t i = 0; i < L; ++i) {
ret_data[i] = arr1_data[i];
ret_data[i + L] = arr2_data[i];
}
return ret;
}
CSRMatrix SliceRows(const CSRMatrix& csr, int64_t start, int64_t end) {
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data);
const dgl_id_t* data = static_cast<dgl_id_t*>(csr.data->data);
const int64_t num_rows = end - start;
const int64_t nnz = indptr[end] - indptr[start];
CSRMatrix ret;
ret.indptr = NewIdArray(num_rows + 1);
ret.indices = NewIdArray(nnz);
ret.data = NewIdArray(nnz);
dgl_id_t* r_indptr = static_cast<dgl_id_t*>(ret.indptr->data);
dgl_id_t* r_indices = static_cast<dgl_id_t*>(ret.indices->data);
dgl_id_t* r_data = static_cast<dgl_id_t*>(ret.data->data);
for (int64_t i = start; i < end + 1; ++i) {
r_indptr[i - start] = indptr[i] - indptr[start];
}
std::copy(indices + indptr[start], indices + indptr[end], r_indices);
std::copy(data + indptr[start], data + indptr[end], r_data);
return ret;
}
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/arith.h
* \brief Arithmetic functors
*/
#ifndef DGL_ARRAY_ARITH_H_
#define DGL_ARRAY_ARITH_H_
namespace dgl {
namespace aten {
namespace arith {
struct Add {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 + t2;
}
};
struct Sub {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 - t2;
}
};
struct Mul {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 * t2;
}
};
struct Div {
template <typename T>
inline static T Call(const T& t1, const T& t2) {
return t1 / t2;
}
};
struct LT {
template <typename T>
inline static bool Call(const T& t1, const T& t2) {
return t1 < t2;
}
};
} // namespace arith
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_ARITH_H_
/*!
* Copyright (c) 2019 by Contributors
* \file array/array.cc
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include "../c_api_common.h"
#include "./array_op.h"
#include "./arith.h"
#include "./common.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
// Allocates an uninitialized 1-D signed-integer array of `length` elements on
// the given device, with `nbits`-bit elements (single lane).
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) {
  const DLDataType dtype{kDLInt, nbits, 1};
  return IdArray::Empty({length}, dtype, ctx);
}
// Deep-copies a 1-D id array, preserving its length, device context and
// integer bit-width.
IdArray Clone(IdArray arr) {
  IdArray copy = NewIdArray(arr->shape[0], arr->ctx, arr->dtype.bits);
  copy.CopyFrom(arr);
  return copy;
}
// Dispatcher: builds the integer sequence [low, low+1, ..., high-1] on the
// given device, routing on device type and the requested bit-width.
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
  IdArray ret;
  ATEN_XPU_SWITCH(ctx.device_type, XPU, {
    if (nbits == 32) {
      ret = impl::Range<XPU, int32_t>(low, high, ctx);
    } else if (nbits == 64) {
      ret = impl::Range<XPU, int64_t>(low, high, ctx);
    } else {
      LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return ret;
}
// Dispatcher: builds a length-`length` array with every element equal to
// `val`, routing on device type and the requested bit-width.
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
  IdArray ret;
  ATEN_XPU_SWITCH(ctx.device_type, XPU, {
    if (nbits == 32) {
      ret = impl::Full<XPU, int32_t>(val, length, ctx);
    } else if (nbits == 64) {
      ret = impl::Full<XPU, int64_t>(val, length, ctx);
    } else {
      LOG(FATAL) << "Only int32 or int64 is supported.";
    }
  });
  return ret;
}
// Dispatcher: converts `arr` to the requested integer bit-width, routing on
// the array's device and current element type (the CPU impl returns the
// input unchanged when it already has the target width).
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  IdArray ret;
  ATEN_XPU_SWITCH(arr->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(arr->dtype, IdType, {
      ret = impl::AsNumBits<XPU, IdType>(arr, bits);
    });
  });
  return ret;
}
// The arithmetic dispatchers below share one pattern: validate operands
// (array/array forms only), then route to the matching impl::BinaryElewise
// template based on the lhs device type and integer width.
// Elementwise add of two arrays with identical context and dtype.
IdArray Add(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
  CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
    });
  });
  return ret;
}
// Elementwise subtract of two arrays with identical context and dtype.
IdArray Sub(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
  CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
    });
  });
  return ret;
}
// Elementwise multiply of two arrays with identical context and dtype.
IdArray Mul(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
  CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
    });
  });
  return ret;
}
// Elementwise (integer) divide of two arrays with identical context and dtype.
IdArray Div(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
  CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
    });
  });
  return ret;
}
// Array (op) scalar variants: apply the scalar to every element.
IdArray Add(IdArray lhs, dgl_id_t rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Add>(lhs, rhs);
    });
  });
  return ret;
}
IdArray Sub(IdArray lhs, dgl_id_t rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
    });
  });
  return ret;
}
IdArray Mul(IdArray lhs, dgl_id_t rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Mul>(lhs, rhs);
    });
  });
  return ret;
}
IdArray Div(IdArray lhs, dgl_id_t rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
    });
  });
  return ret;
}
// Scalar (op) array variants. Add and Mul are commutative, so they reuse the
// array-lhs overloads; Sub and Div must dispatch to the scalar-lhs kernel.
IdArray Add(dgl_id_t lhs, IdArray rhs) {
  return Add(rhs, lhs);
}
IdArray Sub(dgl_id_t lhs, IdArray rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Sub>(lhs, rhs);
    });
  });
  return ret;
}
IdArray Mul(dgl_id_t lhs, IdArray rhs) {
  return Mul(rhs, lhs);
}
IdArray Div(dgl_id_t lhs, IdArray rhs) {
  IdArray ret;
  ATEN_XPU_SWITCH(rhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::Div>(lhs, rhs);
    });
  });
  return ret;
}
// Elementwise less-than against a scalar. The kernel stores the bool result
// into an integer array of the lhs dtype (BoolArray is an NDArray typedef).
BoolArray LT(IdArray lhs, dgl_id_t rhs) {
  BoolArray ret;
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::BinaryElewise<XPU, IdType, arith::LT>(lhs, rhs);
    });
  });
  return ret;
}
// Dispatcher: concatenates two equal-length arrays (see impl::HStack); both
// operands must share device context and dtype.
IdArray HStack(IdArray lhs, IdArray rhs) {
  IdArray ret;
  CHECK_EQ(lhs->ctx, rhs->ctx) << "Both operands should have the same device context";
  CHECK_EQ(lhs->dtype, rhs->dtype) << "Both operands should have the same dtype";
  ATEN_XPU_SWITCH(lhs->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, {
      ret = impl::HStack<XPU, IdType>(lhs, rhs);
    });
  });
  return ret;
}
// Dispatcher: gathers array[index[i]] for every position of `index`.
IdArray IndexSelect(IdArray array, IdArray index) {
  IdArray ret;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      ret = impl::IndexSelect<XPU, IdType>(array, index);
    });
  });
  return ret;
}
// Dispatcher: scalar variant; returns the single element array[index]
// widened to int64_t.
int64_t IndexSelect(IdArray array, int64_t index) {
  int64_t ret = 0;
  ATEN_XPU_SWITCH(array->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
      ret = impl::IndexSelect<XPU, IdType>(array, index);
    });
  });
  return ret;
}
// Dispatcher: relabels ids in all given arrays *in place* (note the trailing
// underscore); device and element type are taken from arrays[0].
IdArray Relabel_(const std::vector<IdArray>& arrays) {
  IdArray ret;
  ATEN_XPU_SWITCH(arrays[0]->ctx.device_type, XPU, {
    ATEN_ID_TYPE_SWITCH(arrays[0]->dtype, IdType, {
      ret = impl::Relabel_<XPU, IdType>(arrays);
    });
  });
  return ret;
}
///////////////////////// CSR routines //////////////////////////
// All functions below are thin dispatchers: each routes to the matching
// impl:: template based on the matrix's device context and index type.
// ATEN_CSR_SWITCH additionally typedefs DType = IdType (data stores edge
// ids today); ATEN_CSR_IDX_SWITCH dispatches on the index type only.
// See common.h for the macro definitions.
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
  bool ret = false;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return ret;
}
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
  NDArray ret;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRIsNonZero<XPU, IdType>(csr, row, col);
  });
  return ret;
}
bool CSRHasDuplicate(CSRMatrix csr) {
  bool ret = false;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRHasDuplicate<XPU, IdType>(csr);
  });
  return ret;
}
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
  int64_t ret = 0;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return ret;
}
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray row) {
  NDArray ret;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
  });
  return ret;
}
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
  NDArray ret;
  ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRGetRowColumnIndices<XPU, IdType>(csr, row);
  });
  return ret;
}
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
  NDArray ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRGetRowData<XPU, IdType, DType>(csr, row);
  });
  return ret;
}
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
  NDArray ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRGetData<XPU, IdType, DType>(csr, row, col);
  });
  return ret;
}
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
  NDArray ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRGetData<XPU, IdType, DType>(csr, rows, cols);
  });
  return ret;
}
std::vector<NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, NDArray rows, NDArray cols) {
  std::vector<NDArray> ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRGetDataAndIndices<XPU, IdType, DType>(csr, rows, cols);
  });
  return ret;
}
CSRMatrix CSRTranspose(CSRMatrix csr) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRTranspose<XPU, IdType, DType>(csr);
  });
  return ret;
}
// When data_as_order is true, the COO rows/cols are emitted in the order
// given by the CSR data array (impl::CSRToCOODataAsOrder); otherwise the
// plain impl::CSRToCOO expansion is used.
COOMatrix CSRToCOO(CSRMatrix csr, bool data_as_order) {
  COOMatrix ret;
  if (data_as_order) {
    ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
        ret = impl::CSRToCOODataAsOrder<XPU, IdType>(csr);
      });
    });
  } else {
    ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, {
      ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, {
        ret = impl::CSRToCOO<XPU, IdType>(csr);
      });
    });
  }
  return ret;
}
// Contiguous row-range slice [start, end).
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRSliceRows<XPU, IdType, DType>(csr, start, end);
  });
  return ret;
}
// Arbitrary row-id slice.
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRSliceRows<XPU, IdType, DType>(csr, rows);
  });
  return ret;
}
CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
    ret = impl::CSRSliceMatrix<XPU, IdType, DType>(csr, rows, cols);
  });
  return ret;
}
///////////////////////// COO routines //////////////////////////
// Thin dispatchers routing to impl:: templates based on the COO matrix's
// device context and index type (ATEN_COO_SWITCH typedefs DType = IdType;
// see common.h).
bool COOHasDuplicate(COOMatrix coo) {
  bool ret = false;
  ATEN_COO_IDX_SWITCH(coo, XPU, IdType, {
    ret = impl::COOHasDuplicate<XPU, IdType>(coo);
  });
  return ret;
}
CSRMatrix COOToCSR(COOMatrix coo) {
  CSRMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, DType, {
    ret = impl::COOToCSR<XPU, IdType, DType>(coo);
  });
  return ret;
}
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/array_op.h
* \brief Array operator templates
*/
#ifndef DGL_ARRAY_ARRAY_OP_H_
#define DGL_ARRAY_ARRAY_OP_H_
#include <dgl/array.h>
#include <vector>
namespace dgl {
namespace aten {
namespace impl {
// ----- dense id-array operators (specialized per device / index type) -----
// Length-`length` array filled with `val` on `ctx`.
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx);
// Integer sequence [low, low+1, ..., high-1] on `ctx`.
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx);
// Convert `arr` to the requested integer bit-width.
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits);
// Elementwise Op over array/array, array/scalar, scalar/array operands.
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs);
// Concatenate arr2 after arr1.
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2);
// Gather array[index[i]] for each i / return the single element array[index].
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index);
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index);
// Relabel ids in-place to a consecutive range; returns the new->old map.
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
    CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr);
// Convert CSR to COO using data array as order
template <DLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
// Row slicing: contiguous range [start, end) / arbitrary row ids.
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix COOToCSR(COOMatrix coo);
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_ARRAY_OP_H_
/*!
* Copyright (c) 2019 by Contributors
* \file array/common.h
* \brief Array operator common utilities
*/
#ifndef DGL_ARRAY_COMMON_H_
#define DGL_ARRAY_COMMON_H_
namespace dgl {
namespace aten {
// Macro to dispatch according to device type. Defines `XPU` as a
// compile-time device-type constant inside the generated block.
// Only CPU is handled; any other device aborts with LOG(FATAL).
#define ATEN_XPU_SWITCH(val, XPU, ...) \
  if ((val) == kDLCPU) { \
    constexpr auto XPU = kDLCPU; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "Device type: " << (val) << " is not supported."; \
  }
// Macro to dispatch according to integer bit-width. Typedefs `IdType`
// (int32_t or int64_t) inside the generated block; non-integer or other
// widths abort with LOG(FATAL).
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) \
  CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
  if ((val).bits == 32) { \
    typedef int32_t IdType; \
    {__VA_ARGS__} \
  } else if ((val).bits == 64) { \
    typedef int64_t IdType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "ID can Only be int32 or int64"; \
  }
// Macro to dispatch according to CSR data type. Typedefs `DType`
// (int32_t or int64_t) inside the generated block.
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) \
  if ((val).code == kDLInt && (val).bits == 32) { \
    typedef int32_t DType; \
    {__VA_ARGS__} \
  } else if ((val).code == kDLInt && (val).bits == 64) { \
    typedef int64_t DType; \
    {__VA_ARGS__} \
  } else { \
    LOG(FATAL) << "CSR matrix data can only be int32 or int64"; \
  }
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_CSR_SWITCH(csr, XPU, IdType, DType, ...) \
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
      typedef IdType DType; \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context and index type
#define ATEN_CSR_IDX_SWITCH(csr, XPU, IdType, ...) \
  ATEN_XPU_SWITCH(csr.indptr->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(csr.indptr->dtype, IdType, { \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context, index type and data type
// TODO(minjie): In our current use cases, data type and id type are the
// same. For example, data array is used to store edge ids.
#define ATEN_COO_SWITCH(coo, XPU, IdType, DType, ...) \
  ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
      typedef IdType DType; \
      {__VA_ARGS__} \
    }); \
  });
// Macro to dispatch according to device context and index type
#define ATEN_COO_IDX_SWITCH(coo, XPU, IdType, ...) \
  ATEN_XPU_SWITCH(coo.row->ctx.device_type, XPU, { \
    ATEN_ID_TYPE_SWITCH(coo.row->dtype, IdType, { \
      {__VA_ARGS__} \
    }); \
  });
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_COMMON_H_
/*!
* Copyright (c) 2019 by Contributors
* \file array/cpu/array_op_impl.cc
* \brief Array operator CPU implementation
*/
#include <dgl/array.h>
#include <numeric>
#include "../arith.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
///////////////////////////// AsNumBits /////////////////////////////
// Casts a 1-D id array to the requested integer bit-width (32 or 64).
// Returns the input itself (no copy) when it already has the target width.
template <DLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
  CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
  if (sizeof(IdType) * 8 == bits) {
    return arr;
  }
  const int64_t len = arr->shape[0];
  IdArray converted = NewIdArray(len, arr->ctx, bits);
  const IdType* src = static_cast<IdType*>(arr->data);
  // std::copy performs the per-element integer conversion to the target width.
  if (bits == 32) {
    std::copy(src, src + len, static_cast<int32_t*>(converted->data));
  } else {
    std::copy(src, src + len, static_cast<int64_t*>(converted->data));
  }
  return converted;
}
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits);
///////////////////////////// BinaryElewise /////////////////////////////
/*!
 * \brief Elementwise binary op between two equal-length id arrays.
 * \return A new array of the lhs's length/context/bit-width where
 *         ret[i] = Op::Call(lhs[i], rhs[i]).
 */
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
  // Without this check a shorter rhs would be read out of bounds
  // (consistent with the length check in HStack below).
  CHECK_EQ(lhs->shape[0], rhs->shape[0])
    << "Both operands should have the same length";
  const int64_t len = lhs->shape[0];
  IdArray ret = NewIdArray(len, lhs->ctx, lhs->dtype.bits);
  const IdType* lhs_data = static_cast<IdType*>(lhs->data);
  const IdType* rhs_data = static_cast<IdType*>(rhs->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  for (int64_t i = 0; i < len; ++i) {
    ret_data[i] = Op::Call(lhs_data[i], rhs_data[i]);
  }
  return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
/*!
 * \brief Elementwise binary op between an id array and a scalar.
 * \return A new array where ret[i] = Op::Call(lhs[i], rhs).
 */
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
  const int64_t len = lhs->shape[0];
  IdArray out = NewIdArray(len, lhs->ctx, lhs->dtype.bits);
  const IdType* src = static_cast<IdType*>(lhs->data);
  IdType* dst = static_cast<IdType*>(out->data);
  for (int64_t i = 0; i < len; ++i) {
    dst[i] = Op::Call(src[i], rhs);
  }
  return out;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
/*!
 * \brief Elementwise binary op between a scalar and an id array.
 * \return A new array where ret[i] = Op::Call(lhs, rhs[i]).
 */
template <DLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
  const int64_t len = rhs->shape[0];
  IdArray out = NewIdArray(len, rhs->ctx, rhs->dtype.bits);
  const IdType* src = static_cast<IdType*>(rhs->data);
  IdType* dst = static_cast<IdType*>(out->data);
  for (int64_t i = 0; i < len; ++i) {
    dst[i] = Op::Call(lhs, src[i]);
  }
  return out;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
///////////////////////////// HStack /////////////////////////////
/*!
 * \brief Concatenates arr2 after arr1 into a single array of length 2*L.
 * Both inputs must have the same length.
 */
template <DLDeviceType XPU, typename IdType>
IdArray HStack(IdArray arr1, IdArray arr2) {
  CHECK_EQ(arr1->shape[0], arr2->shape[0]);
  const int64_t L = arr1->shape[0];
  // Preserve the inputs' context and bit-width. The previous
  // NewIdArray(2 * L) used the default (64-bit CPU) dtype, so writing
  // int32 elements into it produced a corrupted result.
  IdArray ret = NewIdArray(2 * L, arr1->ctx, arr1->dtype.bits);
  const IdType* arr1_data = static_cast<IdType*>(arr1->data);
  const IdType* arr2_data = static_cast<IdType*>(arr2->data);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  for (int64_t i = 0; i < L; ++i) {
    ret_data[i] = arr1_data[i];
    ret_data[i + L] = arr2_data[i];
  }
  return ret;
}
template IdArray HStack<kDLCPU, int32_t>(IdArray arr1, IdArray arr2);
template IdArray HStack<kDLCPU, int64_t>(IdArray arr1, IdArray arr2);
///////////////////////////// Full /////////////////////////////
// Returns a length-`length` array on `ctx` with every element set to `val`;
// the element width follows IdType.
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx) {
  IdArray out = NewIdArray(length, ctx, sizeof(IdType) * 8);
  IdType* begin = static_cast<IdType*>(out->data);
  std::fill_n(begin, length, val);
  return out;
}
template IdArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
///////////////////////////// Range /////////////////////////////
/*!
 * \brief Returns the sequence [low, low+1, ..., high-1] on `ctx`.
 * An empty range (high == low) is valid and yields an empty array.
 */
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
  // high == low is accepted, so the message must say "at least" — the
  // previous "must be bigger than" wording contradicted the check itself.
  CHECK(high >= low) << "high must be greater than or equal to low";
  IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  std::iota(ret_data, ret_data + high - low, low);
  return ret;
}
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
///////////////////////////// IndexSelect /////////////////////////////
/*!
 * \brief Gathers array[index[i]] for every i.
 * \return A new array of the same dtype/ctx as `array`, of index's length.
 */
template <DLDeviceType XPU, typename IdType>
IdArray IndexSelect(IdArray array, IdArray index) {
  const IdType* array_data = static_cast<IdType*>(array->data);
  const IdType* idx_data = static_cast<IdType*>(index->data);
  const int64_t arr_len = array->shape[0];
  const int64_t len = index->shape[0];
  IdArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
  IdType* ret_data = static_cast<IdType*>(ret->data);
  for (int64_t i = 0; i < len; ++i) {
    // IdType is signed: also reject negative indices, which previously
    // slipped past the upper-bound check and read out of bounds.
    CHECK_GE(idx_data[i], 0) << "Index out of range.";
    CHECK_LT(idx_data[i], arr_len) << "Index out of range.";
    ret_data[i] = array_data[idx_data[i]];
  }
  return ret;
}
template IdArray IndexSelect<kDLCPU, int32_t>(IdArray, IdArray);
template IdArray IndexSelect<kDLCPU, int64_t>(IdArray, IdArray);
/*!
 * \brief Returns the single element array[index], widened to int64_t.
 */
template <DLDeviceType XPU, typename IdType>
int64_t IndexSelect(IdArray array, int64_t index) {
  // Bounds-check for consistency with the array-index overload; the
  // unchecked access was silent undefined behavior on a bad index.
  CHECK_GE(index, 0) << "Index out of range.";
  CHECK_LT(index, array->shape[0]) << "Index out of range.";
  const IdType* data = static_cast<IdType*>(array->data);
  return data[index];
}
template int64_t IndexSelect<kDLCPU, int32_t>(IdArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(IdArray array, int64_t index);
///////////////////////////// Relabel_ /////////////////////////////
/*!
 * \brief Relabels ids in all given arrays *in place* to the consecutive
 *        range [0, #unique), ordered by first occurrence across the arrays.
 * \return The mapping array where maparr[newid] == oldid.
 */
template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
  // build map & relabel
  IdType newid = 0;
  std::unordered_map<IdType, IdType> oldv2newv;
  for (IdArray arr : arrays) {
    IdType* data = static_cast<IdType*>(arr->data);  // hoist the repeated cast
    for (int64_t i = 0; i < arr->shape[0]; ++i) {
      const IdType id = data[i];
      if (!oldv2newv.count(id)) {
        oldv2newv[id] = newid++;
      }
      data[i] = oldv2newv[id];
    }
  }
  // map array
  // Allocate the mapping with the same bit-width as the relabeled ids; the
  // previous NewIdArray(newid) always produced a 64-bit array, corrupting
  // the result for 32-bit ids. This is the CPU impl, so the context is CPU.
  IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
  IdType* maparr_data = static_cast<IdType*>(maparr->data);
  for (const auto& kv : oldv2newv) {
    maparr_data[kv.second] = kv.first;
  }
  return maparr;
}
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
} // namespace impl
} // namespace aten
} // namespace dgl
This diff is collapsed.
......@@ -12,6 +12,16 @@
#include <algorithm>
#include <vector>
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& t1, const DLDataType& t2) {
  // Data types are equal only when type code, bit-width and lane count all match.
  return (t1.code == t2.code) && (t1.bits == t2.bits) && (t1.lanes == t2.lanes);
}
/*! \brief Output the string representation of a data type.*/
inline std::ostream& operator << (std::ostream& os, const DLDataType& ty) {
  // code/bits are uint8_t — cast so they stream as numbers, not raw chars —
  // and keep a "," before each field so the output reads "code=0,bits=32,lanes=1".
  return os << "code=" << static_cast<int>(ty.code)
            << ",bits=" << static_cast<int>(ty.bits)
            << ",lanes=" << ty.lanes;
}
/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
......@@ -19,7 +29,7 @@ inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
/*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
return os << "" << ctx.device_type << ":" << ctx.device_id << "";
return os << ctx.device_type << ":" << ctx.device_id;
}
namespace dgl {
......@@ -45,8 +55,7 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ctx.device_type == kDLCPU && arr->ndim == 1
&& arr->dtype.code == kDLInt && arr->dtype.bits == 64;
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
......
......@@ -138,9 +138,9 @@ ImmutableGraph GraphOp::DisjointUnion(std::vector<const ImmutableGraph *> graphs
num_nodes += gr->NumVertices();
num_edges += gr->NumEdges();
}
IdArray indptr_arr = NewIdArray(num_nodes + 1);
IdArray indices_arr = NewIdArray(num_edges);
IdArray edge_ids_arr = NewIdArray(num_edges);
IdArray indptr_arr = aten::NewIdArray(num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(num_edges);
IdArray edge_ids_arr = aten::NewIdArray(num_edges);
dgl_id_t* indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
......@@ -207,9 +207,9 @@ std::vector<ImmutableGraph> GraphOp::DisjointPartitionBySizes(const ImmutableGra
const int64_t end_pos = cumsum[i + 1];
const int64_t g_num_nodes = sizes_data[i];
const int64_t g_num_edges = indptr[end_pos] - indptr[start_pos];
IdArray indptr_arr = NewIdArray(g_num_nodes + 1);
IdArray indices_arr = NewIdArray(g_num_edges);
IdArray edge_ids_arr = NewIdArray(g_num_edges);
IdArray indptr_arr = aten::NewIdArray(g_num_nodes + 1);
IdArray indices_arr = aten::NewIdArray(g_num_edges);
IdArray edge_ids_arr = aten::NewIdArray(g_num_edges);
dgl_id_t* g_indptr = static_cast<dgl_id_t*>(indptr_arr->data);
dgl_id_t* g_indices = static_cast<dgl_id_t*>(indices_arr->data);
dgl_id_t* g_edge_ids = static_cast<dgl_id_t*>(edge_ids_arr->data);
......@@ -329,13 +329,13 @@ Graph GraphOp::ToBidirectedMutableGraph(const GraphInterface* g) {
for (dgl_id_t v = u; v < g->NumVertices(); ++v) {
const auto new_n_e = std::max(n_e[u][v], n_e[v][u]);
if (new_n_e > 0) {
IdArray us = NewIdArray(new_n_e);
IdArray us = aten::NewIdArray(new_n_e);
dgl_id_t* us_data = static_cast<dgl_id_t*>(us->data);
std::fill(us_data, us_data + new_n_e, u);
if (u == v) {
bg.AddEdges(us, us);
} else {
IdArray vs = NewIdArray(new_n_e);
IdArray vs = aten::NewIdArray(new_n_e);
dgl_id_t* vs_data = static_cast<dgl_id_t*>(vs->data);
std::fill(vs_data, vs_data + new_n_e, v);
bg.AddEdges(us, vs);
......@@ -380,8 +380,8 @@ ImmutableGraph GraphOp::ToBidirectedImmutableGraph(const GraphInterface* g) {
}
}
IdArray srcs_array = VecToIdArray(srcs);
IdArray dsts_array = VecToIdArray(dsts);
IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = aten::VecToIdArray(dsts);
COOPtr coo(new COO(g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()));
return ImmutableGraph(coo);
}
......
This diff is collapsed.
......@@ -24,24 +24,24 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
CHECK_GE(layer1_start, layer0_size);
if (fmt == std::string("csr")) {
dgl_id_t first_vid = layer1_start - layer0_size;
CSRMatrix csr = SliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
auto csr = aten::CSRSliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
if (remap) {
dgl_id_t *eid_data = static_cast<dgl_id_t*>(csr.data->data);
const dgl_id_t first_eid = eid_data[0];
IdArray new_indices = Sub(csr.indices, first_vid);
IdArray new_data = Sub(csr.data, first_eid);
IdArray new_indices = aten::Sub(csr.indices, first_vid);
IdArray new_data = aten::Sub(csr.data, first_eid);
return {csr.indptr, new_indices, new_data};
} else {
return {csr.indptr, csr.indices, csr.data};
}
} else if (fmt == std::string("coo")) {
CSRMatrix csr = graph.GetInCSR()->ToCSRMatrix();
auto csr = graph.GetInCSR()->ToCSRMatrix();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr.data->data);
int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
IdArray idx = NewIdArray(2 * nnz);
IdArray eid = NewIdArray(nnz);
IdArray idx = aten::NewIdArray(2 * nnz);
IdArray eid = aten::NewIdArray(nnz);
int64_t *idx_data = static_cast<int64_t*>(idx->data);
dgl_id_t *eid_data = static_cast<dgl_id_t*>(eid->data);
size_t num_edges = 0;
......
......@@ -248,10 +248,10 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
int64_t num_edges, int num_hops, bool is_multigraph) {
NodeFlow nf;
uint64_t num_vertices = sub_vers->size();
nf.node_mapping = NewIdArray(num_vertices);
nf.edge_mapping = NewIdArray(num_edges);
nf.layer_offsets = NewIdArray(num_hops + 1);
nf.flow_offsets = NewIdArray(num_hops);
nf.node_mapping = aten::NewIdArray(num_vertices);
nf.edge_mapping = aten::NewIdArray(num_edges);
nf.layer_offsets = aten::NewIdArray(num_hops + 1);
nf.flow_offsets = aten::NewIdArray(num_hops);
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data);
......@@ -379,6 +379,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
int num_hops,
size_t num_neighbor,
const bool add_self_loop) {
CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet";
unsigned int time_seed = randseed();
const size_t num_seeds = seeds.size();
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
......@@ -702,8 +703,9 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf;
auto sub_csr = CSRPtr(new CSR(
VecToIdArray(sub_indptr), VecToIdArray(sub_indices), VecToIdArray(sub_edge_ids)));
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) {
nf.graph = GraphPtr(new ImmutableGraph(sub_csr, nullptr));
......@@ -711,10 +713,10 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
nf.graph = GraphPtr(new ImmutableGraph(nullptr, sub_csr));
}
nf.node_mapping = VecToIdArray(node_mapping);
nf.edge_mapping = VecToIdArray(edge_mapping);
nf.layer_offsets = VecToIdArray(layer_offsets);
nf.flow_offsets = VecToIdArray(flow_offsets);
nf.node_mapping = aten::VecToIdArray(node_mapping);
nf.edge_mapping = aten::VecToIdArray(edge_mapping);
nf.layer_offsets = aten::VecToIdArray(layer_offsets);
nf.flow_offsets = aten::VecToIdArray(flow_offsets);
return nf;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment