Unverified Commit 5747542f authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Kernel] Migrate batching/unbatching on adjlist to CSR/COO (#1687)



* start

* coo csr union partition

* lint

* lint

* lint

* Add matrix->data transform

* update

* Fix Windows compile

* Add CSR support for DisjointPartition

* lint

* Fix

* Use IdArray Op

* Concat ready

* Fix and all pass

* resolve comments

* Add union COO C++ test

* Add C++ test for csr

* lint

* trigger

* Update include

* Fix merge

* test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 55072b4e
...@@ -181,6 +181,19 @@ NDArray Repeat(NDArray array, IdArray repeats); ...@@ -181,6 +181,19 @@ NDArray Repeat(NDArray array, IdArray repeats);
*/ */
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
/*!
 * \brief Concatenate the given id arrays, in order, into one new array.
 *
 * Example:
 *
 *   Given two IdArrays [2, 3, 10, 0, 2] and [4, 10, 5]
 *   Return [2, 3, 10, 0, 2, 4, 10, 5]
 *
 * All input arrays must have the same dtype and live on the same device
 * context; the implementation checks every array against the first one.
 *
 * \param arrays The id arrays to concatenate.
 * \return A newly allocated array with the dtype and context of the first
 *         input, whose length is the sum of the input lengths.
 */
NDArray Concat(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/ /*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) { inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt; return arr->ndim == 1 && arr->dtype.code == kDLInt;
......
...@@ -371,6 +371,92 @@ COOMatrix COORowWiseTopk( ...@@ -371,6 +371,92 @@ COOMatrix COORowWiseTopk(
NDArray weight, NDArray weight,
bool ascending = false); bool ascending = false);
/*!
 * \brief Union a list of COOMatrix into one block-diagonal COOMatrix.
 *
 * The row (column) ids of each input are shifted by the total number of rows
 * (columns) of the inputs that precede it, so the inputs end up as disjoint
 * diagonal blocks of the result.
 *
 * Examples:
 *
 * A = [[0, 0, 1],
 *      [1, 0, 1],
 *      [0, 1, 0]]
 *
 * B = [[0, 0],
 *      [1, 0]]
 *
 * COOMatrix_A.num_rows : 3
 * COOMatrix_A.num_cols : 3
 * COOMatrix_B.num_rows : 2
 * COOMatrix_B.num_cols : 2
 *
 * C = DisjointUnionCoo({A, B});
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0]]
 * COOMatrix_C.num_rows : 5
 * COOMatrix_C.num_cols : 5
 *
 * If at least one input carries a data (edge id) array, the result carries one
 * too: existing data arrays are shifted by the number of entries of the
 * preceding inputs, and inputs without a data array are given consecutive ids.
 * The row_sorted / col_sorted flags of the result are the logical AND of the
 * corresponding flags of the inputs.
 *
 * \param coos The input list of COO matrices. It must contain more than one
 *             matrix, and the row arrays must share the same dtype and
 *             device context.
 * \return The combined COOMatrix.
 */
COOMatrix DisjointUnionCoo(
const std::vector<COOMatrix>& coos);
/*!
 * \brief Split a COOMatrix into multiple disjoint components.
 *
 * Component g receives the entries stored at positions
 * [edge_cumsum[g], edge_cumsum[g + 1]) of the input arrays, with
 * src_vertex_cumsum[g] subtracted from the row ids and dst_vertex_cumsum[g]
 * subtracted from the column ids. If the input has a data array, the selected
 * data values have edge_cumsum[g] subtracted.
 *
 * Examples:
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0],
 *      [0, 0, 0, 0, 1]]
 * COOMatrix_C.num_rows : 6
 * COOMatrix_C.num_cols : 5
 *
 * batch_size : 2
 * edge_cumsum : [0, 4, 6]
 * src_vertex_cumsum : [0, 3, 6]
 * dst_vertex_cumsum : [0, 3, 5]
 *
 * ret = DisjointPartitionCooBySizes(C,
 *                                   batch_size,
 *                                   edge_cumsum,
 *                                   src_vertex_cumsum,
 *                                   dst_vertex_cumsum)
 *
 * A = [[0, 0, 1],
 *      [1, 0, 1],
 *      [0, 1, 0]]
 * COOMatrix_A.num_rows : 3
 * COOMatrix_A.num_cols : 3
 *
 * B = [[0, 0],
 *      [1, 0],
 *      [0, 1]]
 * COOMatrix_B.num_rows : 3
 * COOMatrix_B.num_cols : 2
 *
 * \param coo COOMatrix to split.
 * \param batch_size Number of disjoint components (sub COOMatrix).
 * \param edge_cumsum Cumulative sum of the number of edges of the components
 *                    (batch_size + 1 entries, as in the example above).
 * \param src_vertex_cumsum Cumulative sum of the number of src vertices of
 *                          the components (batch_size + 1 entries).
 * \param dst_vertex_cumsum Cumulative sum of the number of dst vertices of
 *                          the components (batch_size + 1 entries).
 * \return A list of COOMatrix, one per disjoint component. Each one inherits
 *         the row_sorted / col_sorted flags of the input.
 */
std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo,
const uint64_t batch_size,
const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -366,6 +366,92 @@ COOMatrix CSRRowWiseTopk( ...@@ -366,6 +366,92 @@ COOMatrix CSRRowWiseTopk(
FloatArray weight, FloatArray weight,
bool ascending = false); bool ascending = false);
/*!
 * \brief Union a list of CSRMatrix into one block-diagonal CSRMatrix.
 *
 * The rows of the inputs are stacked in input order, and the column ids of
 * each input are shifted by the total number of columns of the inputs that
 * precede it.
 *
 * Examples:
 *
 * A = [[0, 0, 1],
 *      [1, 0, 1],
 *      [0, 1, 0]]
 *
 * B = [[0, 0],
 *      [1, 0]]
 *
 * CSRMatrix_A.num_rows : 3
 * CSRMatrix_A.num_cols : 3
 * CSRMatrix_B.num_rows : 2
 * CSRMatrix_B.num_cols : 2
 *
 * C = DisjointUnionCsr({A, B});
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0]]
 * CSRMatrix_C.num_rows : 5
 * CSRMatrix_C.num_cols : 5
 *
 * If at least one input carries a data (edge id) array, the result carries one
 * too; inputs without a data array are given consecutive ids. The sorted flag
 * of the result is the logical AND of the sorted flags of the inputs.
 *
 * \param csrs The input list of CSR matrices. It must contain more than one
 *             matrix; the indptr arrays must share the same dtype and the
 *             indices arrays the same device context.
 * \return The combined CSRMatrix.
 */
CSRMatrix DisjointUnionCsr(
const std::vector<CSRMatrix>& csrs);
/*!
 * \brief Split a CSRMatrix into multiple disjoint components.
 *
 * Component g receives rows [src_vertex_cumsum[g], src_vertex_cumsum[g + 1])
 * and the non-zero entries stored at positions
 * [edge_cumsum[g], edge_cumsum[g + 1]) of the input, with
 * dst_vertex_cumsum[g] subtracted from the column ids and edge_cumsum[g]
 * subtracted from the indptr values (and from the data values, if present).
 *
 * Examples:
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0],
 *      [0, 0, 0, 0, 1]]
 * CSRMatrix_C.num_rows : 6
 * CSRMatrix_C.num_cols : 5
 *
 * batch_size : 2
 * edge_cumsum : [0, 4, 6]
 * src_vertex_cumsum : [0, 3, 6]
 * dst_vertex_cumsum : [0, 3, 5]
 *
 * ret = DisjointPartitionCsrBySizes(C,
 *                                   batch_size,
 *                                   edge_cumsum,
 *                                   src_vertex_cumsum,
 *                                   dst_vertex_cumsum)
 *
 * A = [[0, 0, 1],
 *      [1, 0, 1],
 *      [0, 1, 0]]
 * CSRMatrix_A.num_rows : 3
 * CSRMatrix_A.num_cols : 3
 *
 * B = [[0, 0],
 *      [1, 0],
 *      [0, 1]]
 * CSRMatrix_B.num_rows : 3
 * CSRMatrix_B.num_cols : 2
 *
 * \param csr CSRMatrix to split.
 * \param batch_size Number of disjoint components (sub CSRMatrix).
 * \param edge_cumsum Cumulative sum of the number of edges of the components
 *                    (batch_size + 1 entries, as in the example above).
 * \param src_vertex_cumsum Cumulative sum of the number of src vertices of
 *                          the components (batch_size + 1 entries).
 * \param dst_vertex_cumsum Cumulative sum of the number of dst vertices of
 *                          the components (batch_size + 1 entries).
 * \return A list of CSRMatrix, one per disjoint component. Each one inherits
 *         the sorted flag of the input.
 */
std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
  const CSRMatrix &csr,
  const uint64_t batch_size,
  const std::vector<uint64_t> &edge_cumsum,
  const std::vector<uint64_t> &src_vertex_cumsum,
  const std::vector<uint64_t> &dst_vertex_cumsum);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -55,6 +55,19 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) { ...@@ -55,6 +55,19 @@ inline std::string ToStringSparseFormat(SparseFormat sparse_format) {
return std::string("auto"); return std::string("auto");
} }
/*!
 * \brief Translate a sparse format into its numeric format code.
 *
 * Mapping: kAny -> 0, kCOO -> 1, kCSR -> 2, kCSC -> 3, anything else -> 4.
 *
 * \param sparse_format The sparse format to translate.
 * \return The numeric code of the format.
 */
inline dgl_format_code_t SparseFormat2Code(SparseFormat sparse_format) {
  if (sparse_format == SparseFormat::kAny) {
    return 0;
  }
  if (sparse_format == SparseFormat::kCOO) {
    return 1;
  }
  if (sparse_format == SparseFormat::kCSR) {
    return 2;
  }
  if (sparse_format == SparseFormat::kCSC) {
    return 3;
  }
  // Every remaining format shares the fallback code.
  return 4;
}
// Sparse matrix object that is exposed to python API. // Sparse matrix object that is exposed to python API.
struct SparseMatrix : public runtime::Object { struct SparseMatrix : public runtime::Object {
// Sparse format. // Sparse format.
......
...@@ -699,6 +699,9 @@ template <class IdType> ...@@ -699,6 +699,9 @@ template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph( HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*!
 * \brief Batch a list of graphs into one graph by taking the disjoint union
 *        of their sparse matrices, one edge type at a time.
 *
 * For every edge type of the meta graph the implementation intersects the
 * restrict formats of the component relation graphs and unions the matrices
 * in the first available format, trying COO, then CSR, then CSC.
 *
 * \param meta_graph Meta graph shared by all the component graphs.
 * \param component_graphs The graphs to batch; must not be empty.
 * \return The batched graph.
 */
HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*! /*!
* \brief Split a graph into multiple disjoin components. * \brief Split a graph into multiple disjoin components.
* *
...@@ -726,6 +729,13 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( ...@@ -726,6 +729,13 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
IdArray vertex_sizes, IdArray vertex_sizes,
IdArray edge_sizes); IdArray edge_sizes);
/*!
 * \brief Split a batched graph into its components by slicing its sparse
 *        matrices (COO, CSR or CSC, whichever the graph already has in use;
 *        tried in that order).
 *
 * The batch size is the length of vertex_sizes divided by the number of
 * vertex types. For each type, the given sizes must add up to the number of
 * vertices / edges of that type in the batched graph.
 *
 * \param meta_graph Meta graph of the batched graph.
 * \param batched_graph The graph to split.
 * \param vertex_sizes int64 array of per-component vertex counts, flattened
 *        type-major: entry [vtype * batch_size + g] belongs to component g.
 * \param edge_sizes int64 array of per-component edge counts, flattened
 *        type-major: entry [etype * batch_size + g] belongs to component g.
 * \return The component graphs, in batch order.
 */
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
GraphPtr meta_graph,
HeteroGraphPtr batched_graph,
IdArray vertex_sizes,
IdArray edge_sizes);
/*! /*!
* \brief Structure for pickle/unpickle. * \brief Structure for pickle/unpickle.
* *
...@@ -794,6 +804,15 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph); ...@@ -794,6 +804,15 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph);
*/ */
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states); HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states);
/*!
 * \brief Test whether a sparse-format bitmask contains a given format.
 *
 * Bit 0 stands for COO, bit 1 for CSR and bit 2 for CSC. Each macro evaluates
 * to a non-zero value when the corresponding bit is set and to 0 otherwise.
 * The argument is parenthesized so that compound expressions such as
 * FORMAT_HAS_CSC(a | b) test the whole expression rather than only its last
 * operand (`&` binds tighter than `|`).
 */
#define FORMAT_HAS_CSC(format)  \
  ((format) & (1<<2))
#define FORMAT_HAS_CSR(format)  \
  ((format) & (1<<1))
#define FORMAT_HAS_COO(format)  \
  ((format) & 1)
} // namespace dgl } // namespace dgl
#endif // DGL_BASE_HETEROGRAPH_H_ #endif // DGL_BASE_HETEROGRAPH_H_
...@@ -1131,7 +1131,7 @@ def disjoint_union(metagraph, graphs): ...@@ -1131,7 +1131,7 @@ def disjoint_union(metagraph, graphs):
HeteroGraphIndex HeteroGraphIndex
Batched Heterograph. Batched Heterograph.
""" """
return _CAPI_DGLHeteroDisjointUnion(metagraph, graphs) return _CAPI_DGLHeteroDisjointUnion_v2(metagraph, graphs)
def disjoint_partition(graph, bnn_all_types, bne_all_types): def disjoint_partition(graph, bnn_all_types, bne_all_types):
"""Partition the graph disjointly. """Partition the graph disjointly.
...@@ -1152,7 +1152,7 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types): ...@@ -1152,7 +1152,7 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types):
""" """
bnn_all_types = utils.toindex(list(itertools.chain.from_iterable(bnn_all_types))) bnn_all_types = utils.toindex(list(itertools.chain.from_iterable(bnn_all_types)))
bne_all_types = utils.toindex(list(itertools.chain.from_iterable(bne_all_types))) bne_all_types = utils.toindex(list(itertools.chain.from_iterable(bne_all_types)))
return _CAPI_DGLHeteroDisjointPartitionBySizes( return _CAPI_DGLHeteroDisjointPartitionBySizes_v2(
graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor()) graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor())
################################################################# #################################################################
......
...@@ -173,6 +173,42 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -173,6 +173,42 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return ret; return ret;
} }
/*!
 * \brief Concatenate the given arrays, in order, into one newly allocated array.
 *
 * Every array must have the same dtype and device context as the first one.
 * A single-element list is accepted and yields a copy of that array.
 *
 * \param arrays The arrays to concatenate; must not be empty.
 * \return An array with the dtype and context of arrays[0] whose length is the
 *         sum of the input lengths.
 */
NDArray Concat(const std::vector<IdArray>& arrays) {
  CHECK(!arrays.empty()) << "Number of arrays should be at least 1";

  int64_t len = 0;
  for (size_t i = 0; i < arrays.size(); ++i) {
    len += arrays[i]->shape[0];
    CHECK_SAME_DTYPE(arrays[0], arrays[i]);
    CHECK_SAME_CONTEXT(arrays[0], arrays[i]);
  }

  NDArray ret_arr = NDArray::Empty({len},
                                   arrays[0]->dtype,
                                   arrays[0]->ctx);

  auto device = runtime::DeviceAPI::Get(arrays[0]->ctx);
  // All inputs share one dtype (checked above), so the element type only needs
  // to be resolved once for the whole copy loop.
  ATEN_DTYPE_SWITCH(arrays[0]->dtype, DType, "array", {
    int64_t offset = 0;  // byte offset of the next chunk inside ret_arr
    for (size_t i = 0; i < arrays.size(); ++i) {
      const size_t nbytes = arrays[i]->shape[0] * sizeof(DType);
      device->CopyDataFromTo(
        static_cast<DType*>(arrays[i]->data),
        0,
        static_cast<DType*>(ret_arr->data),
        offset,
        nbytes,
        arrays[i]->ctx,
        ret_arr->ctx,
        arrays[i]->dtype,
        nullptr);
      offset += nbytes;
    }
  });

  return ret_arr;
}
template<typename ValueType> template<typename ValueType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) { std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, ValueType pad_value) {
std::tuple<NDArray, IdArray, IdArray> ret; std::tuple<NDArray, IdArray, IdArray> ret;
...@@ -692,6 +728,7 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) { ...@@ -692,6 +728,7 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
}); });
return ret; return ret;
} }
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source, IdArray source,
const bool has_reverse_edge, const bool has_reverse_edge,
...@@ -716,6 +753,7 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, ...@@ -716,6 +753,7 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
return ret; return ret;
} }
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat") DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLSparseMatrixGetFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
......
...@@ -55,6 +55,9 @@ NDArray Repeat(NDArray array, IdArray repeats); ...@@ -55,6 +55,9 @@ NDArray Repeat(NDArray array, IdArray repeats);
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays); IdArray Relabel_(const std::vector<IdArray>& arrays);
/*!
 * \brief Device- and id-type-specific declaration of array concatenation.
 *
 * NOTE(review): neither a definition nor a call site of this template is
 * visible in this change; confirm that it is implemented or still needed.
 */
template <DLDeviceType XPU, typename IdType>
NDArray Concat(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename DType> template <DLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value); std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
...@@ -146,6 +149,7 @@ template <DLDeviceType XPU, typename IdType, typename DType> ...@@ -146,6 +149,7 @@ template <DLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk( COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending); CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
template <DLDeviceType XPU, typename IdType> template <DLDeviceType XPU, typename IdType>
...@@ -234,6 +238,8 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr, ...@@ -234,6 +238,8 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
const bool has_nontree_edge, const bool has_nontree_edge,
const bool return_labels); const bool return_labels);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/coo_union_partition.cc
* \brief COO and CSR union and partition
*/
#include <dgl/array.h>
#include <vector>
namespace dgl {
namespace aten {
///////////////////////// COO Based Operations/////////////////////////
/*!
 * \brief Build the block-diagonal union of several COO matrices.
 *
 * Row/column ids of each input are shifted by the accumulated number of
 * rows/columns of the inputs before it. When any input has a data array, a
 * data array is produced for the result as well.
 *
 * \param coos The matrices to union (more than one is required).
 * \return The combined matrix.
 */
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
  CHECK(coos.size() > 1) <<
    "The length of input COOMatrix vector should be larger than 1";
  const size_t num_mats = coos.size();

  // Pass 1: validate dtype/context and find out whether any input has data.
  bool any_data = false;
  for (size_t i = 0; i < num_mats; ++i) {
    CHECK_SAME_DTYPE(coos[0].row, coos[i].row);
    CHECK_SAME_CONTEXT(coos[0].row, coos[i].row);
    if (COOHasData(coos[i])) {
      any_data = true;
    }
  }

  // Pass 2: shift every matrix into its diagonal block.
  std::vector<IdArray> row_parts(num_mats);
  std::vector<IdArray> col_parts(num_mats);
  std::vector<IdArray> data_parts;
  uint64_t row_shift = 0, col_shift = 0;
  int64_t eid_shift = 0;
  bool all_row_sorted = true;
  bool all_col_sorted = true;
  for (size_t k = 0; k < num_mats; ++k) {
    const aten::COOMatrix &mat = coos[k];
    all_row_sorted = all_row_sorted && mat.row_sorted;
    all_col_sorted = all_col_sorted && mat.col_sorted;

    row_parts[k] = mat.row + row_shift;
    col_parts[k] = mat.col + col_shift;
    row_shift += mat.num_rows;
    col_shift += mat.num_cols;

    if (!any_data) {
      continue;
    }
    // At least one input has a data array: inputs lacking one get
    // consecutive ids, the others are shifted past the earlier entries.
    IdArray shifted_data;
    if (COOHasData(mat)) {
      shifted_data = mat.data + eid_shift;
    } else {
      shifted_data = Range(eid_shift,
                           eid_shift + mat.row->shape[0],
                           mat.row->dtype.bits,
                           mat.row->ctx);
    }
    data_parts.push_back(shifted_data);
    eid_shift += mat.row->shape[0];
  }

  IdArray all_rows = Concat(row_parts);
  IdArray all_cols = Concat(col_parts);
  IdArray all_data = any_data ? Concat(data_parts) : NullArray();

  return COOMatrix(
    row_shift, col_shift,
    all_rows,
    all_cols,
    all_data,
    all_row_sorted,
    all_col_sorted);
}
std::vector<COOMatrix> DisjointPartitionCooBySizes(
const COOMatrix &coo,
const uint64_t batch_size,
const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum) {
CHECK_EQ(edge_cumsum.size(), batch_size + 1);
CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
std::vector<COOMatrix> ret;
ret.resize(batch_size);
for (size_t g = 0; g < batch_size; ++g) {
IdArray result_src = IndexSelect(coo.row,
edge_cumsum[g],
edge_cumsum[g + 1]) - src_vertex_cumsum[g];
IdArray result_dst = IndexSelect(coo.col,
edge_cumsum[g],
edge_cumsum[g + 1]) - dst_vertex_cumsum[g];
IdArray result_data = NullArray();
// has data index array
if (COOHasData(coo)) {
result_data = IndexSelect(coo.data,
edge_cumsum[g],
edge_cumsum[g + 1]) - edge_cumsum[g];
}
COOMatrix sub_coo = COOMatrix(
src_vertex_cumsum[g+1]-src_vertex_cumsum[g],
dst_vertex_cumsum[g+1]-dst_vertex_cumsum[g],
result_src,
result_dst,
result_data,
coo.row_sorted,
coo.col_sorted);
ret[g] = sub_coo;
}
return ret;
}
///////////////////////// CSR Based Operations/////////////////////////
/*!
 * \brief Build the block-diagonal union of several CSR matrices.
 *
 * Rows are stacked in input order; column ids of each input are shifted by
 * the accumulated number of columns of the inputs before it, and indptr
 * values by the accumulated number of non-zero entries. When any input has a
 * data array, a data array is produced for the result as well (inputs without
 * one get consecutive ids).
 *
 * \param csrs The matrices to union (more than one is required).
 * \return The combined matrix; it is flagged sorted only if all inputs are.
 */
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
  CHECK(csrs.size() > 1) <<
    "The length of input CSRMatrix vector should be larger than 1";
  uint64_t src_offset = 0, dst_offset = 0;
  int64_t indices_offset = 0;
  bool has_data = false;
  bool sorted = true;

  // check if data index array
  for (size_t i = 0; i < csrs.size(); ++i) {
    CHECK_SAME_DTYPE(csrs[0].indptr, csrs[i].indptr);
    CHECK_SAME_CONTEXT(csrs[0].indices, csrs[i].indices);
    has_data |= CSRHasData(csrs[i]);
  }

  std::vector<IdArray> res_indptr(csrs.size());
  std::vector<IdArray> res_indices(csrs.size());
  std::vector<IdArray> res_data;
  if (has_data) {
    res_data.reserve(csrs.size());
  }

  for (size_t i = 0; i < csrs.size(); ++i) {
    const aten::CSRMatrix &csr = csrs[i];
    sorted &= csr.sorted;

    IdArray indptr = csr.indptr + indices_offset;
    IdArray indices = csr.indices + dst_offset;
    // The leading entry of every indptr but the first duplicates the last
    // entry of the previous one, so it is dropped before concatenation.
    if (i > 0) {
      indptr = IndexSelect(indptr,
                           1,
                           indptr->shape[0]);
    }
    res_indptr[i] = indptr;
    res_indices[i] = indices;
    src_offset += csr.num_rows;
    dst_offset += csr.num_cols;

    // any one of input csr has data index array
    if (has_data) {
      IdArray edges_data;
      if (CSRHasData(csr) == false) {
        edges_data = Range(indices_offset,
                           indices_offset + csr.indices->shape[0],
                           csr.indices->dtype.bits,
                           csr.indices->ctx);
      } else {
        edges_data = csr.data + indices_offset;
      }
      res_data.push_back(edges_data);
    }

    // The indptr shift must advance for every matrix, whether or not data
    // arrays are being built; otherwise the indptr of later matrices is not
    // offset and the merged CSR is corrupt when no input has a data array.
    indices_offset += csr.indices->shape[0];
  }

  IdArray result_indptr = Concat(res_indptr);
  IdArray result_indices = Concat(res_indices);
  IdArray result_data = has_data ? Concat(res_data) : NullArray();

  return CSRMatrix(
    src_offset, dst_offset,
    result_indptr,
    result_indices,
    result_data,
    sorted);
}
std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const CSRMatrix &csr,
const uint64_t batch_size,
const std::vector<uint64_t> &edge_cumsum,
const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum) {
CHECK_EQ(edge_cumsum.size(), batch_size + 1);
CHECK_EQ(src_vertex_cumsum.size(), batch_size + 1);
CHECK_EQ(dst_vertex_cumsum.size(), batch_size + 1);
std::vector<CSRMatrix> ret;
ret.resize(batch_size);
for (size_t g = 0; g < batch_size; ++g) {
uint64_t num_src = src_vertex_cumsum[g+1]-src_vertex_cumsum[g];
IdArray result_indptr;
if (g == 0) {
result_indptr = IndexSelect(csr.indptr,
0,
src_vertex_cumsum[1] + 1) - edge_cumsum[0];
} else {
result_indptr = IndexSelect(csr.indptr,
src_vertex_cumsum[g],
src_vertex_cumsum[g+1] + 1) - edge_cumsum[g];
}
IdArray result_indices = IndexSelect(csr.indices,
edge_cumsum[g],
edge_cumsum[g+1]) - dst_vertex_cumsum[g];
IdArray result_data = NullArray();
// has data index array
if (CSRHasData(csr)) {
result_data = IndexSelect(csr.data,
edge_cumsum[g],
edge_cumsum[g+1]) - edge_cumsum[g];
}
CSRMatrix sub_csr = CSRMatrix(
num_src,
dst_vertex_cumsum[g+1]-dst_vertex_cumsum[g],
result_indptr,
result_indices,
result_data,
csr.sorted);
ret[g] = sub_csr;
}
return ret;
}
} // namespace aten
} // namespace dgl
...@@ -433,6 +433,45 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -433,6 +433,45 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
*rv = HeteroGraphRef(hg_new); *rv = HeteroGraphRef(hg_new);
}); });
// C API: batch a list of heterographs through DisjointUnionHeteroGraph2.
//   args[0]: the meta graph shared by all components.
//   args[1]: the list of graphs to batch; it must be non-empty, and all graphs
//            must use the same index bit width and the same device context.
// Returns the batched graph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef meta_graph = args[0];
    List<HeteroGraphRef> component_graphs = args[1];
    CHECK(component_graphs.size() > 0)
      << "Expect graph list has at least one graph";
    std::vector<HeteroGraphPtr> component_ptrs;
    component_ptrs.reserve(component_graphs.size());
    // The first graph is the reference every other graph is compared against.
    const int64_t bits = component_graphs[0]->NumBits();
    const DLContext ctx = component_graphs[0]->Context();
    for (const auto& component : component_graphs) {
      component_ptrs.push_back(component.sptr());
      CHECK_EQ(component->NumBits(), bits)
        << "Expect graphs to batch have the same index dtype(int" << bits
        << "), but got int" << component->NumBits();
      CHECK_EQ(component->Context(), ctx)
        << "Expect graphs to batch have the same context(" << ctx
        << "), but got " << component->Context();
    }

    auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs);
    *rv = HeteroGraphRef(hgptr);
  });
// C API: unbatch a heterograph through DisjointPartitionHeteroBySizes2.
//   args[0]: the batched graph.
//   args[1]: flattened per-component vertex counts.
//   args[2]: flattened per-component edge counts.
// Returns the list of component graphs.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    const IdArray vertex_sizes = args[1];
    const IdArray edge_sizes = args[2];

    const std::vector<HeteroGraphPtr> parts = DisjointPartitionHeteroBySizes2(
      hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);

    // Wrap every component so it can be handed back across the C API.
    List<HeteroGraphRef> wrapped;
    for (const auto& part : parts) {
      wrapped.push_back(HeteroGraphRef(part));
    }
    *rv = wrapped;
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
......
...@@ -8,6 +8,209 @@ using namespace dgl::runtime; ...@@ -8,6 +8,209 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
/*!
 * \brief Batch a list of graphs by taking the disjoint union of their sparse
 *        matrices, one edge type at a time.
 *
 * For every edge type, the restrict formats of the component relation graphs
 * are intersected, and the matrices are unioned in the first format that
 * survives, trying COO, then CSR, then CSC. The resulting relation graphs are
 * created without a format restriction (SparseFormat::kAny).
 *
 * \param meta_graph Meta graph shared by all the component graphs.
 * \param component_graphs The graphs to batch; must not be empty.
 * \return The batched graph.
 */
HeteroGraphPtr DisjointUnionHeteroGraph2(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
  CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

  // ALL = CSC | CSR | COO. Format code c occupies bit (c - 1) of the mask.
  const dgl_format_code_t all_formats =
    (1 << (SparseFormat2Code(SparseFormat::kCOO)-1)) |
    (1 << (SparseFormat2Code(SparseFormat::kCSR)-1)) |
    (1 << (SparseFormat2Code(SparseFormat::kCSC)-1));

  // Loop over all canonical etypes
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    const int64_t num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;
    uint64_t src_offset = 0, dst_offset = 0;
    HeteroGraphPtr rgptr = nullptr;

    // Intersect the restrict formats of all components and count the nodes.
    dgl_format_code_t format = all_formats;
    for (size_t i = 0; i < component_graphs.size(); ++i) {
      const auto& cg = component_graphs[i];
      const std::string restrict_format = cg->GetRelationGraph(etype)->GetRestrictFormat();
      const SparseFormat curr_format = ParseSparseFormat(restrict_format);

      // Only a concrete restriction narrows the mask; any other format
      // leaves all candidates available.
      if (curr_format == SparseFormat::kCOO ||
          curr_format == SparseFormat::kCSR ||
          curr_format == SparseFormat::kCSC)
        format &= (1 << (SparseFormat2Code(curr_format)-1));

      // Update offsets
      src_offset += cg->NumVertices(src_vtype);
      dst_offset += cg->NumVertices(dst_vtype);
    }
    CHECK_GT(format, 0) << "The conjunction of restrict_format of the relation graphs under "
                        << etype << " should not be None.";

    // prefer COO
    if (FORMAT_HAS_COO(format)) {
      std::vector<aten::COOMatrix> coos;
      coos.reserve(component_graphs.size());
      for (size_t i = 0; i < component_graphs.size(); ++i) {
        coos.push_back(component_graphs[i]->GetCOOMatrix(etype));
      }

      aten::COOMatrix res = aten::DisjointUnionCoo(coos);
      rgptr = UnitGraph::CreateFromCOO(num_vtypes, res, SparseFormat::kAny);
    } else if (FORMAT_HAS_CSR(format)) {
      std::vector<aten::CSRMatrix> csrs;
      csrs.reserve(component_graphs.size());
      for (size_t i = 0; i < component_graphs.size(); ++i) {
        csrs.push_back(component_graphs[i]->GetCSRMatrix(etype));
      }

      aten::CSRMatrix res = aten::DisjointUnionCsr(csrs);
      rgptr = UnitGraph::CreateFromCSR(num_vtypes, res, SparseFormat::kAny);
    } else if (FORMAT_HAS_CSC(format)) {
      // A CSC matrix is stored as a CSRMatrix, so the CSR union is reused.
      std::vector<aten::CSRMatrix> cscs;
      cscs.reserve(component_graphs.size());
      for (size_t i = 0; i < component_graphs.size(); ++i) {
        cscs.push_back(component_graphs[i]->GetCSCMatrix(etype));
      }

      aten::CSRMatrix res = aten::DisjointUnionCsr(cscs);
      rgptr = UnitGraph::CreateFromCSC(num_vtypes, res, SparseFormat::kAny);
    }
    rel_graphs[etype] = rgptr;
    num_nodes_per_type[src_vtype] = src_offset;
    num_nodes_per_type[dst_vtype] = dst_offset;
  }

  return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}
/*!
 * \brief Split a batched graph into its components by slicing its sparse
 *        matrices.
 *
 * The batch size is the length of vertex_sizes divided by the number of
 * vertex types. The matrices are sliced in the first format that relation
 * graph 0 of the batched graph has in use, trying COO, then CSR, then CSC,
 * and every resulting relation graph is created with the restrict format of
 * relation graph 0.
 *
 * \param meta_graph Meta graph of the batched graph.
 * \param batched_graph The graph to split.
 * \param vertex_sizes int64 array of per-component vertex counts, flattened
 *        type-major: entry [vtype * batch_size + g] belongs to component g.
 * \param edge_sizes int64 array of per-component edge counts, flattened
 *        type-major: entry [etype * batch_size + g] belongs to component g.
 * \return The component graphs, in batch order.
 */
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
  // Sanity check for vertex sizes
  CHECK_EQ(vertex_sizes->dtype.bits, 64) << "dtype of vertex_sizes should be int64";
  CHECK_EQ(edge_sizes->dtype.bits, 64) << "dtype of edge_sizes should be int64";
  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
  const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data);
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;

  // Map vertex type to the corresponding node cum sum
  std::vector<std::vector<uint64_t>> vertex_cumsum;
  vertex_cumsum.resize(num_vertex_types);
  // Loop over all vertex types
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    vertex_cumsum[vtype].reserve(batch_size + 1);
    vertex_cumsum[vtype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of vertices in the batch for all types
      vertex_cumsum[vtype].push_back(
        vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
    }
    CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
      << "Sum of the given sizes must equal to the number of nodes for type " << vtype;
  }

  // Sanity check for edge sizes
  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);
  const uint64_t num_edge_types = meta_graph->NumEdges();
  // Map edge type to the corresponding edge cum sum
  std::vector<std::vector<uint64_t>> edge_cumsum;
  edge_cumsum.resize(num_edge_types);
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    edge_cumsum[etype].reserve(batch_size + 1);
    edge_cumsum[etype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of edges in the batch for all types
      edge_cumsum[etype].push_back(
        edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
    }
    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
      << "Sum of the given sizes must equal to the number of edges for type " << etype;
  }

  // Construct relation graphs for unbatched graphs
  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;
  rel_graphs.resize(batch_size);
  for (uint64_t g = 0; g < batch_size; ++g)
    rel_graphs[g].reserve(num_edge_types);

  // The format in use and the restrict format of relation graph 0 decide how
  // every edge type is sliced and how the resulting graphs are restricted.
  auto format = batched_graph->GetRelationGraph(0)->GetFormatInUse();
  auto restrict_format = batched_graph->GetRelationGraph(0)->GetRestrictFormat();
  const SparseFormat out_format = ParseSparseFormat(restrict_format);

  // Loop over all edge types
  if (FORMAT_HAS_COO(format)) {
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      aten::COOMatrix coo = batched_graph->GetCOOMatrix(etype);
      auto res = aten::DisjointPartitionCooBySizes(coo,
                                                   batch_size,
                                                   edge_cumsum[etype],
                                                   vertex_cumsum[src_vtype],
                                                   vertex_cumsum[dst_vtype]);
      for (uint64_t g = 0; g < batch_size; ++g) {
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
          (src_vtype == dst_vtype) ? 1 : 2, res[g],
          out_format);
        rel_graphs[g].push_back(rgptr);
      }
    }
  } else if (FORMAT_HAS_CSR(format)) {
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      aten::CSRMatrix csr = batched_graph->GetCSRMatrix(etype);
      auto res = aten::DisjointPartitionCsrBySizes(csr,
                                                   batch_size,
                                                   edge_cumsum[etype],
                                                   vertex_cumsum[src_vtype],
                                                   vertex_cumsum[dst_vtype]);
      for (uint64_t g = 0; g < batch_size; ++g) {
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSR(
          (src_vtype == dst_vtype) ? 1 : 2, res[g],
          out_format);
        rel_graphs[g].push_back(rgptr);
      }
    }
  } else if (FORMAT_HAS_CSC(format)) {
    for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
      auto pair = meta_graph->FindEdge(etype);
      const dgl_type_t src_vtype = pair.first;
      const dgl_type_t dst_vtype = pair.second;
      aten::CSRMatrix csc = batched_graph->GetCSCMatrix(etype);
      // In the CSC matrix the destination vertices index the rows, hence the
      // swapped vertex cumsum arguments.
      auto res = aten::DisjointPartitionCsrBySizes(csc,
                                                   batch_size,
                                                   edge_cumsum[etype],
                                                   vertex_cumsum[dst_vtype],
                                                   vertex_cumsum[src_vtype]);
      for (uint64_t g = 0; g < batch_size; ++g) {
        // Propagate the restrict format exactly like the COO and CSR branches
        // do, instead of silently resetting it to kAny.
        HeteroGraphPtr rgptr = UnitGraph::CreateFromCSC(
          (src_vtype == dst_vtype) ? 1 : 2, res[g],
          out_format);
        rel_graphs[g].push_back(rgptr);
      }
    }
  }

  std::vector<HeteroGraphPtr> rst;
  rst.reserve(batch_size);
  std::vector<int64_t> num_nodes_per_type(num_vertex_types);
  for (uint64_t g = 0; g < batch_size; ++g) {
    for (uint64_t i = 0; i < num_vertex_types; ++i)
      num_nodes_per_type[i] = vertex_sizes_data[i * batch_size + g];
    rst.push_back(CreateHeteroGraph(meta_graph, rel_graphs[g], num_nodes_per_type));
  }
  return rst;
}
template <class IdType> template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph( HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
......
...@@ -19,10 +19,11 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N ...@@ -19,10 +19,11 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
for ety in g1.canonical_etypes: for ety in g1.canonical_etypes:
assert g1.number_of_edges(ety) == g2.number_of_edges(ety) assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
src1, dst1 = g1.all_edges(etype=ety) src1, dst1, eid1 = g1.all_edges(etype=ety, form='all')
src2, dst2 = g2.all_edges(etype=ety) src2, dst2, eid2 = g2.all_edges(etype=ety, form='all')
assert F.allclose(src1, src2) assert F.allclose(src1, src2)
assert F.allclose(dst1, dst2) assert F.allclose(dst1, dst2)
assert F.allclose(eid1, eid2)
if node_attrs is not None: if node_attrs is not None:
for nty in node_attrs.keys(): for nty in node_attrs.keys():
...@@ -88,15 +89,121 @@ def test_batching_hetero_topology(index_dtype): ...@@ -88,15 +89,121 @@ def test_batching_hetero_topology(index_dtype):
src, dst = bg.all_edges(etype=('user', 'follows', 'developer')) src, dst = bg.all_edges(etype=('user', 'follows', 'developer'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5] assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 4, 5] assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
src, dst = bg.all_edges(etype='plays') src, dst, eid = bg.all_edges(etype='plays', form='all')
assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6] assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3] assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]
# Test unbatching graphs # Test unbatching graphs
g3, g4 = dgl.unbatch_hetero(bg) g3, g4 = dgl.unbatch_hetero(bg)
check_equivalence_between_heterographs(g1, g3) check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4) check_equivalence_between_heterographs(g2, g4)
"""Test batching two DGLHeteroGraphs with csr format"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
}, index_dtype=index_dtype, restrict_format='csr')
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
}, index_dtype=index_dtype, restrict_format='csr')
bg = dgl.batch_hetero([g1, g2])
# Test number of nodes
for ntype in bg.ntypes:
assert bg.batch_num_nodes(ntype) == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
assert bg.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))
# Test number of edges
assert bg.batch_num_edges('plays') == [
g1.number_of_edges('plays'), g2.number_of_edges('plays')]
assert bg.number_of_edges('plays') == (
g1.number_of_edges('plays') + g2.number_of_edges('plays'))
for etype in bg.canonical_etypes:
assert bg.batch_num_edges(etype) == [
g1.number_of_edges(etype), g2.number_of_edges(etype)]
assert bg.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype))
# Test relabeled nodes
for ntype in bg.ntypes:
assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))
# Test relabeled edges
src, dst = bg.all_edges(etype=('user', 'follows', 'user'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
src, dst = bg.all_edges(etype=('user', 'follows', 'developer'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
src, dst, eid = bg.all_edges(etype='plays', form='all')
assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]
# Test unbatching graphs
g3, g4 = dgl.unbatch_hetero(bg)
check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4)
"""Test batching two DGLHeteroGraphs with csc"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
}, index_dtype=index_dtype, restrict_format='csc')
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
}, index_dtype=index_dtype, restrict_format='csc')
bg = dgl.batch_hetero([g1, g2])
# Test number of nodes
for ntype in bg.ntypes:
assert bg.batch_num_nodes(ntype) == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
assert bg.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))
# Test number of edges
assert bg.batch_num_edges('plays') == [
g1.number_of_edges('plays'), g2.number_of_edges('plays')]
assert bg.number_of_edges('plays') == (
g1.number_of_edges('plays') + g2.number_of_edges('plays'))
for etype in bg.canonical_etypes:
assert bg.batch_num_edges(etype) == [
g1.number_of_edges(etype), g2.number_of_edges(etype)]
assert bg.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype))
# Test relabeled nodes
for ntype in bg.ntypes:
assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))
# Test relabeled edges
src, dst = bg.all_edges(etype=('user', 'follows', 'user'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
src, dst = bg.all_edges(etype=('user', 'follows', 'developer'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
src, dst, eid = bg.all_edges(etype='plays', form='all')
assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]
# Test unbatching graphs
g3, g4 = dgl.unbatch_hetero(bg)
check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4)
@parametrize_dtype @parametrize_dtype
def test_batching_hetero_and_batched_hetero_topology(index_dtype): def test_batching_hetero_and_batched_hetero_topology(index_dtype):
...@@ -296,8 +403,8 @@ def test_to_device(index_dtype): ...@@ -296,8 +403,8 @@ def test_to_device(index_dtype):
assert bg.batch_num_edges('plays') == bg1.batch_num_edges('plays') assert bg.batch_num_edges('plays') == bg1.batch_num_edges('plays')
if __name__ == '__main__': if __name__ == '__main__':
test_batching_hetero_topology() test_batching_hetero_topology('int32')
test_batching_hetero_and_batched_hetero_topology() test_batching_hetero_and_batched_hetero_topology('int32')
test_batched_features() test_batched_features('int32')
test_batching_with_zero_nodes_edges() test_batching_with_zero_nodes_edges('int32')
# test_to_device() # test_to_device()
...@@ -229,9 +229,11 @@ void _TestRelabel_() { ...@@ -229,9 +229,11 @@ void _TestRelabel_() {
IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX); IdArray a = aten::VecToIdArray(std::vector<IDX>({0, 20, 10}), sizeof(IDX)*8, CTX);
IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, CTX); IdArray b = aten::VecToIdArray(std::vector<IDX>({20, 5, 6}), sizeof(IDX)*8, CTX);
IdArray c = aten::Relabel_({a, b}); IdArray c = aten::Relabel_({a, b});
IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX); IdArray ta = aten::VecToIdArray(std::vector<IDX>({0, 1, 2}), sizeof(IDX)*8, CTX);
IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, CTX); IdArray tb = aten::VecToIdArray(std::vector<IDX>({1, 3, 4}), sizeof(IDX)*8, CTX);
IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, CTX); IdArray tc = aten::VecToIdArray(std::vector<IDX>({0, 20, 10, 5, 6}), sizeof(IDX)*8, CTX);
ASSERT_TRUE(ArrayEQ<IDX>(a, ta)); ASSERT_TRUE(ArrayEQ<IDX>(a, ta));
ASSERT_TRUE(ArrayEQ<IDX>(b, tb)); ASSERT_TRUE(ArrayEQ<IDX>(b, tb));
ASSERT_TRUE(ArrayEQ<IDX>(c, tc)); ASSERT_TRUE(ArrayEQ<IDX>(c, tc));
...@@ -242,6 +244,337 @@ TEST(ArrayTest, TestRelabel_) { ...@@ -242,6 +244,337 @@ TEST(ArrayTest, TestRelabel_) {
_TestRelabel_<int64_t>(); _TestRelabel_<int64_t>();
} }
/*!
 * \brief Exercise aten::Concat on two arrays and then on three arrays.
 *
 * The three-way case reuses the result of the two-way concatenation as its
 * last operand, so the expected output is the six-element sequence twice.
 *
 * NOTE(review): `ctx` is never read in this body; every array is created with
 * `CTX`. Confirm whether the GPU instantiations are meant to use `ctx`.
 *
 * \tparam IDX Element type of the literal input vectors.
 * \param ctx Context requested by the caller (not used by this body).
 */
template <typename IDX>
void _TestConcat(DLContext ctx) {
  // Wrap a literal list of values into an array of matching bit width.
  const auto make_array = [](const std::vector<IDX>& values) {
    return aten::VecToIdArray(values, sizeof(IDX) * 8, CTX);
  };

  const IdArray head = make_array({1, 2, 3});
  const IdArray tail = make_array({4, 5, 6});
  const IdArray want_pair = make_array({1, 2, 3, 4, 5, 6});

  // Two operands.
  const IdArray got_pair = aten::Concat(std::vector<IdArray>{head, tail});
  ASSERT_TRUE(ArrayEQ<IDX>(got_pair, want_pair));

  // Three operands, the last one being the previous result.
  const IdArray got_triple =
      aten::Concat(std::vector<IdArray>{head, tail, got_pair});
  const IdArray want_triple =
      make_array({1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6});
  ASSERT_TRUE(ArrayEQ<IDX>(got_triple, want_triple));
}
// Runs _TestConcat for int32_t, int64_t, float and double with the CPU
// context, and again with the GPU context when DGL_USE_CUDA is defined.
// NOTE(review): _TestConcat does not read its context argument (it builds its
// arrays with CTX), so confirm the GPU calls exercise what is intended.
TEST(ArrayTest, TestConcat) {
_TestConcat<int32_t>(CPU);
_TestConcat<int64_t>(CPU);
_TestConcat<float>(CPU);
_TestConcat<double>(CPU);
// GPU instantiations are only compiled in CUDA-enabled builds.
#ifdef DGL_USE_CUDA
_TestConcat<int32_t>(GPU);
_TestConcat<int64_t>(GPU);
_TestConcat<float>(GPU);
_TestConcat<double>(GPU);
#endif
}
/*!
 * \brief Round-trip test for aten::DisjointUnionCoo and
 *        aten::DisjointPartitionCooBySizes.
 *
 * Unions {A, B} and {A, B, C}, checks shape, row/col/data arrays and the
 * sorted flags of each union, then partitions the union back and compares
 * every part with the matrix it came from.
 *
 * NOTE(review): `ctx` is never read in this body; every array is created with
 * `CTX`. Confirm whether the GPU instantiations are meant to use `ctx`.
 *
 * \tparam IdType Integer type used for the index arrays.
 * \param ctx Context requested by the caller (not used by this body).
 */
template <typename IdType>
void _TestDisjointUnionPartitionCoo(DLContext ctx) {
  /*
   * A  = [[0, 0, 1],
   *       [1, 0, 1],
   *       [0, 1, 0]]
   *
   * B  = [[1, 1, 0],
   *       [0, 1, 0]]
   *
   * C  = [[1]]
   *
   * AB = [[0, 0, 1, 0, 0, 0],
   *       [1, 0, 1, 0, 0, 0],
   *       [0, 1, 0, 0, 0, 0],
   *       [0, 0, 0, 1, 1, 0],
   *       [0, 0, 0, 0, 1, 0]]
   *
   * ABC = [[0, 0, 1, 0, 0, 0, 0],
   *        [1, 0, 1, 0, 0, 0, 0],
   *        [0, 1, 0, 0, 0, 0, 0],
   *        [0, 0, 0, 1, 1, 0, 0],
   *        [0, 0, 0, 0, 1, 0, 0],
   *        [0, 0, 0, 0, 0, 0, 1]]
   */
  // Wrap a literal list of indices into an id array of matching bit width.
  const auto ids = [](const std::vector<IdType>& values) {
    return aten::VecToIdArray(values, sizeof(IdType) * 8, CTX);
  };

  // Component matrices; only B carries an explicit data array.
  const IdArray rows_a = ids({0, 1, 1, 2});
  const IdArray cols_a = ids({2, 0, 2, 1});
  const IdArray rows_b = ids({0, 0, 1});
  const IdArray cols_b = ids({0, 1, 1});
  const IdArray data_b = ids({2, 0, 1});
  const IdArray rows_c = ids({0});
  const IdArray cols_c = ids({0});

  // Expected contents of the union of A and B.
  const IdArray want_ab_row = ids({0, 1, 1, 2, 3, 3, 4});
  const IdArray want_ab_col = ids({2, 0, 2, 1, 3, 4, 4});
  const IdArray want_ab_data = ids({0, 1, 2, 3, 6, 4, 5});

  // Expected contents of the union of A, B and C.
  const IdArray want_abc_row = ids({0, 1, 1, 2, 3, 3, 4, 5});
  const IdArray want_abc_col = ids({2, 0, 2, 1, 3, 4, 4, 6});
  const IdArray want_abc_data = ids({0, 1, 2, 3, 6, 4, 5, 7});

  // A is flagged row-sorted only; B and C are flagged row- and col-sorted.
  const aten::COOMatrix mat_a(
      3, 3, rows_a, cols_a, aten::NullArray(), true, false);
  const aten::COOMatrix mat_b(
      2, 3, rows_b, cols_b, data_b, true, true);
  const aten::COOMatrix mat_c(
      1, 1, rows_c, cols_c, aten::NullArray(), true, true);

  // ---- Union of A and B ----
  const std::vector<aten::COOMatrix> inputs_ab{mat_a, mat_b};
  const aten::COOMatrix union_ab = aten::DisjointUnionCoo(inputs_ab);
  ASSERT_EQ(union_ab.num_rows, 5);
  ASSERT_EQ(union_ab.num_cols, 6);
  ASSERT_TRUE(ArrayEQ<IdType>(union_ab.row, want_ab_row));
  ASSERT_TRUE(ArrayEQ<IdType>(union_ab.col, want_ab_col));
  ASSERT_TRUE(ArrayEQ<IdType>(union_ab.data, want_ab_data));
  ASSERT_TRUE(union_ab.row_sorted);
  ASSERT_FALSE(union_ab.col_sorted);

  // ---- Partition the A+B union back into its parts ----
  const std::vector<uint64_t> edge_offsets_ab{0, 4, 7};
  const std::vector<uint64_t> src_offsets_ab{0, 3, 5};
  const std::vector<uint64_t> dst_offsets_ab{0, 3, 6};
  const std::vector<aten::COOMatrix> parts_ab =
      aten::DisjointPartitionCooBySizes(
          union_ab, 2, edge_offsets_ab, src_offsets_ab, dst_offsets_ab);
  ASSERT_EQ(parts_ab[0].num_rows, mat_a.num_rows);
  ASSERT_EQ(parts_ab[0].num_cols, mat_a.num_cols);
  ASSERT_EQ(parts_ab[1].num_rows, mat_b.num_rows);
  ASSERT_EQ(parts_ab[1].num_cols, mat_b.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(parts_ab[0].row, mat_a.row));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_ab[0].col, mat_a.col));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_ab[1].row, mat_b.row));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_ab[1].col, mat_b.col));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_ab[1].data, mat_b.data));
  ASSERT_TRUE(parts_ab[0].row_sorted);
  ASSERT_FALSE(parts_ab[0].col_sorted);
  ASSERT_TRUE(parts_ab[1].row_sorted);
  ASSERT_FALSE(parts_ab[1].col_sorted);

  // ---- Union of A, B and C ----
  const std::vector<aten::COOMatrix> inputs_abc{mat_a, mat_b, mat_c};
  const aten::COOMatrix union_abc = aten::DisjointUnionCoo(inputs_abc);
  ASSERT_EQ(union_abc.num_rows, 6);
  ASSERT_EQ(union_abc.num_cols, 7);
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.row, want_abc_row));
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.col, want_abc_col));
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.data, want_abc_data));
  ASSERT_TRUE(union_abc.row_sorted);
  ASSERT_FALSE(union_abc.col_sorted);

  // ---- Partition the A+B+C union back into its parts ----
  const std::vector<uint64_t> edge_offsets_abc{0, 4, 7, 8};
  const std::vector<uint64_t> src_offsets_abc{0, 3, 5, 6};
  const std::vector<uint64_t> dst_offsets_abc{0, 3, 6, 7};
  const std::vector<aten::COOMatrix> parts_abc =
      aten::DisjointPartitionCooBySizes(
          union_abc, 3, edge_offsets_abc, src_offsets_abc, dst_offsets_abc);
  ASSERT_EQ(parts_abc[0].num_rows, mat_a.num_rows);
  ASSERT_EQ(parts_abc[0].num_cols, mat_a.num_cols);
  ASSERT_EQ(parts_abc[1].num_rows, mat_b.num_rows);
  ASSERT_EQ(parts_abc[1].num_cols, mat_b.num_cols);
  ASSERT_EQ(parts_abc[2].num_rows, mat_c.num_rows);
  ASSERT_EQ(parts_abc[2].num_cols, mat_c.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[0].row, mat_a.row));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[0].col, mat_a.col));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].row, mat_b.row));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].col, mat_b.col));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].data, mat_b.data));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[2].row, mat_c.row));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[2].col, mat_c.col));
  ASSERT_TRUE(parts_abc[0].row_sorted);
  ASSERT_FALSE(parts_abc[0].col_sorted);
  ASSERT_TRUE(parts_abc[1].row_sorted);
  ASSERT_FALSE(parts_abc[1].col_sorted);
  ASSERT_TRUE(parts_abc[2].row_sorted);
  ASSERT_FALSE(parts_abc[2].col_sorted);
}
// Runs the COO union/partition round-trip test for int32_t and int64_t ids
// with the CPU context, and again with the GPU context when DGL_USE_CUDA is
// defined.
// NOTE(review): the helper does not read its context argument (it builds its
// arrays with CTX), so confirm the GPU calls exercise what is intended.
TEST(DisjointUnionTest, TestDisjointUnionPartitionCoo) {
_TestDisjointUnionPartitionCoo<int32_t>(CPU);
_TestDisjointUnionPartitionCoo<int64_t>(CPU);
// GPU instantiations are only compiled in CUDA-enabled builds.
#ifdef DGL_USE_CUDA
_TestDisjointUnionPartitionCoo<int32_t>(GPU);
_TestDisjointUnionPartitionCoo<int64_t>(GPU);
#endif
}
/*!
 * \brief Round-trip test for aten::DisjointUnionCsr and
 *        aten::DisjointPartitionCsrBySizes.
 *
 * Unions {B, C} and {A, B, C}, checks shape, indptr/indices/data arrays and
 * the sorted flag of each union, then partitions the union back and compares
 * every part with the matrix it came from.
 *
 * NOTE(review): `ctx` is never read in this body; every array is created with
 * `CTX`. Confirm whether the GPU instantiations are meant to use `ctx`.
 *
 * \tparam IdType Integer type used for the index arrays.
 * \param ctx Context requested by the caller (not used by this body).
 */
template <typename IdType>
void _TestDisjointUnionPartitionCsr(DLContext ctx) {
  /*
   * A  = [[0, 0, 1],
   *       [1, 0, 1],
   *       [0, 1, 0]]
   *
   * B  = [[1, 1, 0],
   *       [0, 1, 0]]
   *
   * C  = [[1]]
   *
   * BC = [[1, 1, 0, 0],
   *       [0, 1, 0, 0],
   *       [0, 0, 0, 1]]
   *
   * ABC = [[0, 0, 1, 0, 0, 0, 0],
   *        [1, 0, 1, 0, 0, 0, 0],
   *        [0, 1, 0, 0, 0, 0, 0],
   *        [0, 0, 0, 1, 1, 0, 0],
   *        [0, 0, 0, 0, 1, 0, 0],
   *        [0, 0, 0, 0, 0, 0, 1]]
   */
  // Wrap a literal list of indices into an id array of matching bit width.
  const auto ids = [](const std::vector<IdType>& values) {
    return aten::VecToIdArray(values, sizeof(IdType) * 8, CTX);
  };

  // Component matrices; only B carries an explicit data array.
  const IdArray indptr_a = ids({0, 1, 3, 4});
  const IdArray indices_a = ids({2, 0, 2, 1});
  const IdArray indptr_b = ids({0, 2, 3});
  const IdArray indices_b = ids({0, 1, 1});
  const IdArray data_b = ids({2, 0, 1});
  const IdArray indptr_c = ids({0, 1});
  const IdArray indices_c = ids({0});

  // Expected contents of the union of B and C.
  const IdArray want_bc_indptr = ids({0, 2, 3, 4});
  const IdArray want_bc_indices = ids({0, 1, 1, 3});
  const IdArray want_bc_data = ids({2, 0, 1, 3});

  // Expected contents of the union of A, B and C.
  const IdArray want_abc_indptr = ids({0, 1, 3, 4, 6, 7, 8});
  const IdArray want_abc_indices = ids({2, 0, 2, 1, 3, 4, 4, 6});
  const IdArray want_abc_data = ids({0, 1, 2, 3, 6, 4, 5, 7});

  // A is flagged unsorted; B and C are flagged sorted.
  const aten::CSRMatrix mat_a(
      3, 3, indptr_a, indices_a, aten::NullArray(), false);
  const aten::CSRMatrix mat_b(
      2, 3, indptr_b, indices_b, data_b, true);
  const aten::CSRMatrix mat_c(
      1, 1, indptr_c, indices_c, aten::NullArray(), true);

  // ---- Union of B and C ----
  const std::vector<aten::CSRMatrix> inputs_bc{mat_b, mat_c};
  const aten::CSRMatrix union_bc = aten::DisjointUnionCsr(inputs_bc);
  ASSERT_EQ(union_bc.num_rows, 3);
  ASSERT_EQ(union_bc.num_cols, 4);
  ASSERT_TRUE(ArrayEQ<IdType>(union_bc.indptr, want_bc_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(union_bc.indices, want_bc_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(union_bc.data, want_bc_data));
  ASSERT_TRUE(union_bc.sorted);

  // ---- Partition the B+C union back into its parts ----
  const std::vector<uint64_t> edge_offsets_bc{0, 3, 4};
  const std::vector<uint64_t> src_offsets_bc{0, 2, 3};
  const std::vector<uint64_t> dst_offsets_bc{0, 3, 4};
  const std::vector<aten::CSRMatrix> parts_bc =
      aten::DisjointPartitionCsrBySizes(
          union_bc, 2, edge_offsets_bc, src_offsets_bc, dst_offsets_bc);
  ASSERT_EQ(parts_bc[0].num_rows, mat_b.num_rows);
  ASSERT_EQ(parts_bc[0].num_cols, mat_b.num_cols);
  ASSERT_EQ(parts_bc[1].num_rows, mat_c.num_rows);
  ASSERT_EQ(parts_bc[1].num_cols, mat_c.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(parts_bc[0].indptr, mat_b.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_bc[0].indices, mat_b.indices));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_bc[0].data, mat_b.data));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_bc[1].indptr, mat_c.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_bc[1].indices, mat_c.indices));
  ASSERT_TRUE(parts_bc[0].sorted);
  ASSERT_TRUE(parts_bc[1].sorted);

  // ---- Union of A, B and C ----
  const std::vector<aten::CSRMatrix> inputs_abc{mat_a, mat_b, mat_c};
  const aten::CSRMatrix union_abc = aten::DisjointUnionCsr(inputs_abc);
  ASSERT_EQ(union_abc.num_rows, 6);
  ASSERT_EQ(union_abc.num_cols, 7);
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.indptr, want_abc_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.indices, want_abc_indices));
  ASSERT_TRUE(ArrayEQ<IdType>(union_abc.data, want_abc_data));
  ASSERT_FALSE(union_abc.sorted);

  // ---- Partition the A+B+C union back into its parts ----
  const std::vector<uint64_t> edge_offsets_abc{0, 4, 7, 8};
  const std::vector<uint64_t> src_offsets_abc{0, 3, 5, 6};
  const std::vector<uint64_t> dst_offsets_abc{0, 3, 6, 7};
  const std::vector<aten::CSRMatrix> parts_abc =
      aten::DisjointPartitionCsrBySizes(
          union_abc, 3, edge_offsets_abc, src_offsets_abc, dst_offsets_abc);
  ASSERT_EQ(parts_abc[0].num_rows, mat_a.num_rows);
  ASSERT_EQ(parts_abc[0].num_cols, mat_a.num_cols);
  ASSERT_EQ(parts_abc[1].num_rows, mat_b.num_rows);
  ASSERT_EQ(parts_abc[1].num_cols, mat_b.num_cols);
  ASSERT_EQ(parts_abc[2].num_rows, mat_c.num_rows);
  ASSERT_EQ(parts_abc[2].num_cols, mat_c.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[0].indptr, mat_a.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[0].indices, mat_a.indices));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].indptr, mat_b.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].indices, mat_b.indices));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[1].data, mat_b.data));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[2].indptr, mat_c.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(parts_abc[2].indices, mat_c.indices));
  ASSERT_FALSE(parts_abc[0].sorted);
  ASSERT_FALSE(parts_abc[1].sorted);
  ASSERT_FALSE(parts_abc[2].sorted);
}
// Runs the CSR union/partition round-trip test for int32_t and int64_t ids
// with the CPU context, and again with the GPU context when DGL_USE_CUDA is
// defined.
// NOTE(review): the helper does not read its context argument (it builds its
// arrays with CTX), so confirm the GPU calls exercise what is intended.
TEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) {
_TestDisjointUnionPartitionCsr<int32_t>(CPU);
_TestDisjointUnionPartitionCsr<int64_t>(CPU);
// GPU instantiations are only compiled in CUDA-enabled builds.
#ifdef DGL_USE_CUDA
_TestDisjointUnionPartitionCsr<int32_t>(GPU);
_TestDisjointUnionPartitionCsr<int64_t>(GPU);
#endif
}
template <typename IDX> template <typename IDX>
void _TestCumSum(DLContext ctx) { void _TestCumSum(DLContext ctx) {
IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}), IdArray a = aten::VecToIdArray(std::vector<IDX>({8, 6, 7, 5, 3, 0, 9}),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment