"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4433a5b2e7570ca86b08e0c229e3ac5bed8046bb"
Unverified Commit 65e1ba4f authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

Sort csr (#886)

* sort

* sort in parallel.

* fix a bug in sorting adj

* rename.

* add more comments.

* accelerate GetData

* fix tests.

* avoid sorting multiple times.

* add test.

* change back.

* sort.

* add sort_csr.

* Fix a bug.

* fix.

* revert modifcation.

* rename

* speed up EdgeIds.
parent 54e1ef2e
...@@ -141,6 +141,8 @@ struct CSRMatrix { ...@@ -141,6 +141,8 @@ struct CSRMatrix {
runtime::NDArray indptr, indices; runtime::NDArray indptr, indices;
/*! \brief data array, could be empty. */ /*! \brief data array, could be empty. */
runtime::NDArray data; runtime::NDArray data;
/*! \brief indicate that the edges are stored in the sorted order. */
bool sorted;
}; };
/*! /*!
...@@ -249,6 +251,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray ...@@ -249,6 +251,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
/*! \return True if the matrix has duplicate entries */ /*! \return True if the matrix has duplicate entries */
bool CSRHasDuplicate(CSRMatrix csr); bool CSRHasDuplicate(CSRMatrix csr);
/*! Sort the columns in each row in the ascending order. */
void CSRSort(CSRMatrix csr);
///////////////////////// COO routines ////////////////////////// ///////////////////////// COO routines //////////////////////////
/*! \return True if the matrix has duplicate entries */ /*! \return True if the matrix has duplicate entries */
......
...@@ -346,6 +346,15 @@ class GraphInterface : public runtime::Object { ...@@ -346,6 +346,15 @@ class GraphInterface : public runtime::Object {
*/ */
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0; virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;
/*!
* \brief Sort the columns in CSR.
*
* This sorts the columns in each row based on the column Ids.
* The edge ids should be sorted accordingly.
*/
virtual void SortCSR() {
}
static constexpr const char* _type_key = "graph.Graph"; static constexpr const char* _type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
}; };
......
...@@ -240,9 +240,16 @@ class CSR : public GraphInterface { ...@@ -240,9 +240,16 @@ class CSR : public GraphInterface {
IdArray edge_ids() const { return adj_.data; } IdArray edge_ids() const { return adj_.data; }
void SortCSR() {
if (adj_.sorted)
return;
aten::CSRSort(adj_);
adj_.sorted = true;
}
private: private:
/*! \brief prive default constructor */ /*! \brief prive default constructor */
CSR() {} CSR() {adj_.sorted = false;}
// The internal CSR adjacency matrix. // The internal CSR adjacency matrix.
// The data field stores edge ids. // The data field stores edge ids.
...@@ -951,6 +958,11 @@ class ImmutableGraph: public GraphInterface { ...@@ -951,6 +958,11 @@ class ImmutableGraph: public GraphInterface {
*/ */
ImmutableGraphPtr Reverse() const; ImmutableGraphPtr Reverse() const;
void SortCSR() {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
}
protected: protected:
/* !\brief internal default constructor */ /* !\brief internal default constructor */
ImmutableGraph() {} ImmutableGraph() {}
......
...@@ -901,12 +901,17 @@ class DGLGraph(DGLBaseGraph): ...@@ -901,12 +901,17 @@ class DGLGraph(DGLBaseGraph):
node_frame=None, node_frame=None,
edge_frame=None, edge_frame=None,
multigraph=None, multigraph=None,
readonly=False): readonly=False,
sort_csr=False):
# graph # graph
if isinstance(graph_data, DGLGraph): if isinstance(graph_data, DGLGraph):
gidx = graph_data._graph gidx = graph_data._graph
if sort_csr:
gidx.sort_csr()
else: else:
gidx = graph_index.create_graph_index(graph_data, multigraph, readonly) gidx = graph_index.create_graph_index(graph_data, multigraph, readonly)
if sort_csr:
gidx.sort_csr()
super(DGLGraph, self).__init__(gidx) super(DGLGraph, self).__init__(gidx)
# node and edge frame # node and edge frame
......
...@@ -421,6 +421,16 @@ class GraphIndex(ObjectBase): ...@@ -421,6 +421,16 @@ class GraphIndex(ObjectBase):
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2))
return src, dst, eid return src, dst, eid
def sort_csr(self):
"""Sort the CSR matrix in the graph index.
By default, when the CSR matrix is created, the edges may be stored
in an arbitrary order. Sometimes, we want to sort them to accelerate
some computation. For example, `has_edges_between` can be much faster
on a giant adjacency matrix if the edges in the matrix is sorted.
"""
_CAPI_DGLSortAdj(self)
@utils.cached_member(cache='_cache', prefix='edges') @utils.cached_member(cache='_cache', prefix='edges')
def edges(self, order=None): def edges(self, order=None):
"""Return all the edges """Return all the edges
......
...@@ -364,6 +364,12 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) { ...@@ -364,6 +364,12 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, NDArray rows, NDArray cols) {
return ret; return ret;
} }
void CSRSort(CSRMatrix csr) {
ATEN_CSR_SWITCH(csr, XPU, IdType, DType, {
impl::CSRSort<XPU, IdType, DType>(csr);
});
}
///////////////////////// COO routines ////////////////////////// ///////////////////////// COO routines //////////////////////////
bool COOHasDuplicate(COOMatrix coo) { bool COOHasDuplicate(COOMatrix coo) {
......
...@@ -96,6 +96,9 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows); ...@@ -96,6 +96,9 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
void CSRSort(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo); bool COOHasDuplicate(COOMatrix coo);
......
...@@ -95,11 +95,17 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) { ...@@ -95,11 +95,17 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col; CHECK(col >= 0 && col < csr.num_cols) << "Invalid col index: " << col;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
if (csr.sorted) {
const IdType *start = indices_data + indptr_data[row];
const IdType *end = indices_data + indptr_data[row + 1];
return std::binary_search(start, end, col);
} else {
for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) { for (IdType i = indptr_data[row]; i < indptr_data[row + 1]; ++i) {
if (indices_data[i] == col) { if (indices_data[i] == col) {
return true; return true;
} }
} }
}
return false; return false;
} }
...@@ -209,6 +215,27 @@ template NDArray CSRGetRowData<kDLCPU, int64_t, int64_t>(CSRMatrix, int64_t); ...@@ -209,6 +215,27 @@ template NDArray CSRGetRowData<kDLCPU, int64_t, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData ///////////////////////////// ///////////////////////////// CSRGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
void CollectDataFromSorted(const IdType *indices_data, const DType *data,
const IdType start, const IdType end, const IdType col,
std::vector<DType> *ret_vec) {
const IdType *start_ptr = indices_data + start;
const IdType *end_ptr = indices_data + end;
auto it = std::lower_bound(start_ptr, end_ptr, col);
// This might be a multi-graph. We need to collect all of the matched
// columns.
for (; it != end_ptr; it++) {
// If the col exist
if (*it == col) {
IdType idx = it - indices_data;
ret_vec->push_back(data[idx]);
} else {
// If we find a column that is different, we can stop searching now.
break;
}
}
}
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) { NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
CHECK(CSRHasData(csr)) << "missing data array"; CHECK(CSRHasData(csr)) << "missing data array";
...@@ -219,11 +246,17 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) { ...@@ -219,11 +246,17 @@ NDArray CSRGetData(CSRMatrix csr, int64_t row, int64_t col) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data); const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
const DType* data = static_cast<DType*>(csr.data->data); const DType* data = static_cast<DType*>(csr.data->data);
if (csr.sorted) {
CollectDataFromSorted<XPU, IdType, DType>(indices_data, data,
indptr_data[row], indptr_data[row + 1],
col, &ret_vec);
} else {
for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) { for (IdType i = indptr_data[row]; i < indptr_data[row+1]; ++i) {
if (indices_data[i] == col) { if (indices_data[i] == col) {
ret_vec.push_back(data[i]); ret_vec.push_back(data[i]);
} }
} }
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx); return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
} }
...@@ -255,12 +288,18 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) { ...@@ -255,12 +288,18 @@ NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
const IdType row_id = row_data[i], col_id = col_data[j]; const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id; CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id; CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
if (csr.sorted) {
CollectDataFromSorted<XPU, IdType, DType>(indices_data, data,
indptr_data[row_id], indptr_data[row_id + 1],
col_id, &ret_vec);
} else {
for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) { for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
if (indices_data[i] == col_id) { if (indices_data[i] == col_id) {
ret_vec.push_back(data[i]); ret_vec.push_back(data[i]);
} }
} }
} }
}
return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx); return VecToNDArray(ret_vec, csr.data->dtype, csr.data->ctx);
} }
...@@ -270,6 +309,29 @@ template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(CSRMatrix csr, NDArray row ...@@ -270,6 +309,29 @@ template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(CSRMatrix csr, NDArray row
///////////////////////////// CSRGetDataAndIndices ///////////////////////////// ///////////////////////////// CSRGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename DType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const DType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *col_vec,
std::vector<DType> *ret_vec) {
const IdType *start_ptr = indices_data + start;
const IdType *end_ptr = indices_data + end;
auto it = std::lower_bound(start_ptr, end_ptr, col);
// This might be a multi-graph. We need to collect all of the matched
// columns.
for (; it != end_ptr; it++) {
// If the col exist
if (*it == col) {
IdType idx = it - indices_data;
col_vec->push_back(indices_data[idx]);
ret_vec->push_back(data[idx]);
} else {
// If we find a column that is different, we can stop searching now.
break;
}
}
}
template <DLDeviceType XPU, typename IdType, typename DType> template <DLDeviceType XPU, typename IdType, typename DType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) { std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
CHECK(CSRHasData(csr)) << "missing data array"; CHECK(CSRHasData(csr)) << "missing data array";
...@@ -297,6 +359,18 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c ...@@ -297,6 +359,18 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
const IdType row_id = row_data[i], col_id = col_data[j]; const IdType row_id = row_data[i], col_id = col_data[j];
CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id; CHECK(row_id >= 0 && row_id < csr.num_rows) << "Invalid row index: " << row_id;
CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id; CHECK(col_id >= 0 && col_id < csr.num_cols) << "Invalid col index: " << col_id;
if (csr.sorted) {
// Here we collect col indices and data.
CollectDataIndicesFromSorted<XPU, IdType, DType>(indices_data, data,
indptr_data[row_id],
indptr_data[row_id + 1],
col_id, &ret_cols,
&ret_data);
// We need to add row Ids.
while (ret_rows.size() < ret_data.size()) {
ret_rows.push_back(row_id);
}
} else {
for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) { for (IdType i = indptr_data[row_id]; i < indptr_data[row_id+1]; ++i) {
if (indices_data[i] == col_id) { if (indices_data[i] == col_id) {
ret_rows.push_back(row_id); ret_rows.push_back(row_id);
...@@ -305,6 +379,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c ...@@ -305,6 +379,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
} }
} }
} }
}
return {VecToIdArray(ret_rows, csr.indptr->dtype.bits, csr.indptr->ctx), return {VecToIdArray(ret_rows, csr.indptr->dtype.bits, csr.indptr->ctx),
VecToIdArray(ret_cols, csr.indptr->dtype.bits, csr.indptr->ctx), VecToIdArray(ret_cols, csr.indptr->dtype.bits, csr.indptr->ctx),
...@@ -548,6 +623,42 @@ template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t, int32_t>( ...@@ -548,6 +623,42 @@ template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t, int32_t>(
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t, int64_t>( template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols); CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType, typename DType>
void CSRSort(CSRMatrix csr) {
typedef std::pair<IdType, DType> shuffle_ele;
int64_t num_rows = csr.num_rows;
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
IdType* indices_data = static_cast<IdType*>(csr.indices->data);
DType* eid_data = static_cast<DType*>(csr.data->data);
#pragma omp parallel
{
std::vector<shuffle_ele> reorder_vec;
#pragma omp for
for (int64_t row = 0; row < num_rows; row++) {
int64_t num_cols = indptr_data[row + 1] - indptr_data[row];
IdType *col = indices_data + indptr_data[row];
DType *eid = eid_data + indptr_data[row];
reorder_vec.resize(num_cols);
for (int64_t i = 0; i < num_cols; i++) {
reorder_vec[i].first = col[i];
reorder_vec[i].second = eid[i];
}
std::sort(reorder_vec.begin(), reorder_vec.end(),
[](const shuffle_ele &e1, const shuffle_ele &e2) {
return e1.first < e2.first;
});
for (int64_t i = 0; i < num_cols; i++) {
col[i] = reorder_vec[i].first;
eid[i] = reorder_vec[i].second;
}
}
}
}
template void CSRSort<kDLCPU, int64_t, int64_t>(CSRMatrix csr);
template void CSRSort<kDLCPU, int32_t, int32_t>(CSRMatrix csr);
///////////////////////////// COOHasDuplicate ///////////////////////////// ///////////////////////////// COOHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
......
...@@ -346,4 +346,10 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges") ...@@ -346,4 +346,10 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
*rv = subg->induced_edges; *rv = subg->induced_edges;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
g->SortCSR();
});
} // namespace dgl } // namespace dgl
...@@ -56,6 +56,7 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph) ...@@ -56,6 +56,7 @@ CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
aten::NewIdArray(num_vertices + 1), aten::NewIdArray(num_vertices + 1),
aten::NewIdArray(num_edges), aten::NewIdArray(num_edges),
aten::NewIdArray(num_edges)}; aten::NewIdArray(num_edges)};
adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
...@@ -65,6 +66,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { ...@@ -65,6 +66,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1; const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
...@@ -75,6 +77,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) ...@@ -75,6 +77,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
CHECK_EQ(indices->shape[0], edge_ids->shape[0]); CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1; const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
...@@ -93,6 +96,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -93,6 +96,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
adj_.indptr.CopyFrom(indptr); adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices); adj_.indices.CopyFrom(indices);
adj_.data.CopyFrom(edge_ids); adj_.data.CopyFrom(edge_ids);
adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
...@@ -112,6 +116,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, ...@@ -112,6 +116,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
adj_.indptr.CopyFrom(indptr); adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices); adj_.indices.CopyFrom(indices);
adj_.data.CopyFrom(edge_ids); adj_.data.CopyFrom(edge_ids);
adj_.sorted = false;
} }
CSR::CSR(const std::string &shared_mem_name, CSR::CSR(const std::string &shared_mem_name,
...@@ -122,6 +127,7 @@ CSR::CSR(const std::string &shared_mem_name, ...@@ -122,6 +127,7 @@ CSR::CSR(const std::string &shared_mem_name,
adj_.num_cols = num_verts; adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory( std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, false); shared_mem_name, num_verts, num_edges, false);
adj_.sorted = false;
} }
bool CSR::IsMultigraph() const { bool CSR::IsMultigraph() const {
...@@ -196,6 +202,7 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const { ...@@ -196,6 +202,7 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids); const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context()); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids)); CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
subcsr->adj_.sorted = this->adj_.sorted;
Subgraph subg; Subgraph subg;
subg.graph = subcsr; subg.graph = subcsr;
subg.induced_vertices = vids; subg.induced_vertices = vids;
...@@ -230,6 +237,7 @@ CSR CSR::CopyToSharedMem(const std::string &name) const { ...@@ -230,6 +237,7 @@ CSR CSR::CopyToSharedMem(const std::string &name) const {
CHECK(name == shared_mem_name_); CHECK(name == shared_mem_name_);
return *this; return *this;
} else { } else {
// TODO(zhengda) we need to set sorted_ properly.
return CSR(adj_.indptr, adj_.indices, adj_.data, name); return CSR(adj_.indptr, adj_.indices, adj_.data, name);
} }
} }
......
...@@ -338,6 +338,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -338,6 +338,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length"; << "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
} }
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
...@@ -349,10 +350,13 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -349,10 +350,13 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length"; << "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
sorted_ = false;
} }
explicit CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) explicit CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) {} : BaseHeteroGraph(metagraph), adj_(csr) {
sorted_ = false;
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
return 0; return 0;
...@@ -628,6 +632,9 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -628,6 +632,9 @@ class UnitGraph::CSR : public BaseHeteroGraph {
/*! \brief multi-graph flag */ /*! \brief multi-graph flag */
Lazy<bool> is_multigraph_; Lazy<bool> is_multigraph_;
/*! \brief indicate that the edges are stored in the sorted order. */
bool sorted_;
}; };
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
......
...@@ -50,9 +50,8 @@ def gen_by_mutation(): ...@@ -50,9 +50,8 @@ def gen_by_mutation():
g.add_edges(src, dst) g.add_edges(src, dst)
return g return g
def gen_from_data(data, readonly): def gen_from_data(data, readonly, sort):
g = dgl.DGLGraph(data, readonly=readonly) return dgl.DGLGraph(data, readonly=readonly, sort_csr=True)
return g
def test_query(): def test_query():
def _test_one(g): def _test_one(g):
...@@ -201,15 +200,16 @@ def test_query(): ...@@ -201,15 +200,16 @@ def test_query():
_test_csr_one(g) _test_csr_one(g)
_test(gen_by_mutation()) _test(gen_by_mutation())
_test(gen_from_data(elist_input(), False)) _test(gen_from_data(elist_input(), False, False))
_test(gen_from_data(elist_input(), True)) _test(gen_from_data(elist_input(), True, False))
_test(gen_from_data(nx_input(), False)) _test(gen_from_data(elist_input(), True, True))
_test(gen_from_data(nx_input(), True)) _test(gen_from_data(nx_input(), False, False))
_test(gen_from_data(scipy_coo_input(), False)) _test(gen_from_data(nx_input(), True, False))
_test(gen_from_data(scipy_coo_input(), True)) _test(gen_from_data(scipy_coo_input(), False, False))
_test(gen_from_data(scipy_coo_input(), True, False))
_test_csr(gen_from_data(scipy_csr_input(), False))
_test_csr(gen_from_data(scipy_csr_input(), True)) _test_csr(gen_from_data(scipy_csr_input(), False, False))
_test_csr(gen_from_data(scipy_csr_input(), True, False))
def test_mutation(): def test_mutation():
g = dgl.DGLGraph() g = dgl.DGLGraph()
......
...@@ -20,6 +20,7 @@ aten::CSRMatrix CSR1() { ...@@ -20,6 +20,7 @@ aten::CSRMatrix CSR1() {
csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, CTX); csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 5, 5}), sizeof(IDX)*8, CTX);
csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX); csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, CTX); csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 1, 4}), sizeof(IDX)*8, CTX);
csr.sorted = false;
return csr; return csr;
} }
...@@ -37,6 +38,7 @@ aten::CSRMatrix CSR2() { ...@@ -37,6 +38,7 @@ aten::CSRMatrix CSR2() {
csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX); csr.indptr = aten::VecToIdArray(std::vector<IDX>({0, 3, 4, 6, 6}), sizeof(IDX)*8, CTX);
csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX); csr.indices = aten::VecToIdArray(std::vector<IDX>({1, 2, 2, 0, 2, 3}), sizeof(IDX)*8, CTX);
csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX); csr.data = aten::VecToIdArray(std::vector<IDX>({0, 2, 5, 3, 1, 4}), sizeof(IDX)*8, CTX);
csr.sorted = false;
return csr; return csr;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment