"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "269983dbcd221acb64d03817fb25f0e27788bbaf"
Unverified Commit 5be937a7 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Kernel] Slicing Batched Graphs (#2965)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Update

* Update

* Add files via upload

* Add files via upload

* Update

* Lint

* Add files via upload

* Lint

* Update

* Update

* Update

* Update

* Update

* Lint Fix

* Lint
Co-authored-by: Ubuntu <ubuntu@ip-172-31-12-161.us-west-2.compute.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent ba154924
...@@ -90,6 +90,7 @@ operators for computing graph-level representation for both single and batched g ...@@ -90,6 +90,7 @@ operators for computing graph-level representation for both single and batched g
batch batch
unbatch unbatch
slice_batch
readout_nodes readout_nodes
readout_edges readout_edges
sum_nodes sum_nodes
......
...@@ -545,6 +545,47 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -545,6 +545,47 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum); const std::vector<uint64_t> &dst_vertex_cumsum);
/*!
 * \brief Slice a contiguous chunk from a COOMatrix.
 *
 * Each range argument is a two-element vector {start, end} that denotes the
 * half-open interval [start, end); see the example below, where the source
 * range [3, 6] yields a result with 3 rows.
 *
 * Examples:
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0],
 *      [0, 0, 0, 0, 1]]
 * COOMatrix_C.num_rows : 6
 * COOMatrix_C.num_cols : 5
 *
 * edge_range : [4, 6]
 * src_vertex_range : [3, 6]
 * dst_vertex_range : [3, 5]
 *
 * ret = COOSliceContiguousChunk(C,
 *                               edge_range,
 *                               src_vertex_range,
 *                               dst_vertex_range)
 *
 * ret = [[0, 0],
 *        [1, 0],
 *        [0, 1]]
 * COOMatrix_ret.num_rows : 3
 * COOMatrix_ret.num_cols : 2
 *
 * \param coo COOMatrix to slice.
 * \param edge_range ID range {start, end} of the edges in the chunk.
 * \param src_vertex_range ID range {start, end} of the src vertices in the chunk.
 * \param dst_vertex_range ID range {start, end} of the dst vertices in the chunk.
 * \return COOMatrix representing the chunk, with vertex IDs shifted so that
 *         each range starts at 0.
 */
COOMatrix COOSliceContiguousChunk(
const COOMatrix &coo,
const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &dst_vertex_range);
/*! /*!
* \brief Create a LineGraph of input coo * \brief Create a LineGraph of input coo
* *
......
...@@ -554,7 +554,7 @@ CSRMatrix DisjointUnionCsr( ...@@ -554,7 +554,7 @@ CSRMatrix DisjointUnionCsr(
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr); std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(const CSRMatrix& csr);
/*! /*!
* \brief Split a CSRMatrix into multiple disjoin components. * \brief Split a CSRMatrix into multiple disjoint components.
* *
* Examples: * Examples:
* *
...@@ -604,6 +604,47 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -604,6 +604,47 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
const std::vector<uint64_t> &src_vertex_cumsum, const std::vector<uint64_t> &src_vertex_cumsum,
const std::vector<uint64_t> &dst_vertex_cumsum); const std::vector<uint64_t> &dst_vertex_cumsum);
/*!
 * \brief Slice a contiguous chunk from a CSRMatrix.
 *
 * Each range argument is a two-element vector {start, end} that denotes the
 * half-open interval [start, end); see the example below, where the source
 * range [3, 6] yields a result with 3 rows.
 *
 * Examples:
 *
 * C = [[0, 0, 1, 0, 0],
 *      [1, 0, 1, 0, 0],
 *      [0, 1, 0, 0, 0],
 *      [0, 0, 0, 0, 0],
 *      [0, 0, 0, 1, 0],
 *      [0, 0, 0, 0, 1]]
 * CSRMatrix_C.num_rows : 6
 * CSRMatrix_C.num_cols : 5
 *
 * edge_range : [4, 6]
 * src_vertex_range : [3, 6]
 * dst_vertex_range : [3, 5]
 *
 * ret = CSRSliceContiguousChunk(C,
 *                               edge_range,
 *                               src_vertex_range,
 *                               dst_vertex_range)
 *
 * ret = [[0, 0],
 *        [1, 0],
 *        [0, 1]]
 * CSRMatrix_ret.num_rows : 3
 * CSRMatrix_ret.num_cols : 2
 *
 * \param csr CSRMatrix to slice.
 * \param edge_range ID range {start, end} of the edges in the chunk.
 * \param src_vertex_range ID range {start, end} of the src vertices in the chunk.
 * \param dst_vertex_range ID range {start, end} of the dst vertices in the chunk.
 * \return CSRMatrix representing the chunk, with vertex IDs shifted so that
 *         each range starts at 0.
 */
CSRMatrix CSRSliceContiguousChunk(
const CSRMatrix &csr,
const std::vector<uint64_t> &edge_range,
const std::vector<uint64_t> &src_vertex_range,
const std::vector<uint64_t> &dst_vertex_range);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -732,6 +732,27 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -732,6 +732,27 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
HeteroGraphPtr DisjointUnionHeteroGraph2( HeteroGraphPtr DisjointUnionHeteroGraph2(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs); GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
/*!
 * \brief Slice a contiguous subgraph, e.g. retrieve a component graph from a batched graph.
 *
 * The four ID arrays are indexed by type ID: the node arrays hold one entry
 * per vertex type of the metagraph and the edge arrays one entry per edge type.
 *
 * NOTE: the implementation reads the raw buffers of these arrays as 64-bit
 * integers, so callers are expected to pass 64-bit ID arrays.
 *
 * TODO(mufei): remove the meta_graph argument
 *
 * \param meta_graph Metagraph of the input and result.
 * \param batched_graph Input graph.
 * \param num_nodes_per_type Number of vertices of each type in the result.
 * \param start_nid_per_type Start vertex ID of each type to slice.
 * \param num_edges_per_type Number of edges of each type in the result.
 * \param start_eid_per_type Start edge ID of each type to slice.
 * \return Sliced graph
 */
HeteroGraphPtr SliceHeteroGraph(
GraphPtr meta_graph,
HeteroGraphPtr batched_graph,
IdArray num_nodes_per_type,
IdArray start_nid_per_type,
IdArray num_edges_per_type,
IdArray start_eid_per_type);
/*! /*!
* \brief Split a graph into multiple disjoint components. * \brief Split a graph into multiple disjoint components.
* *
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
from collections.abc import Mapping from collections.abc import Mapping
from . import backend as F from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, is_all, DGLError, dgl_warning, NID, EID
from .heterograph_index import disjoint_union from .heterograph_index import disjoint_union, slice_gidx
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
from . import convert from . import convert
from . import utils from . import utils
__all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero'] __all__ = ['batch', 'unbatch', 'slice_batch', 'batch_hetero', 'unbatch_hetero']
def batch(graphs, ndata=ALL, edata=ALL, *, def batch(graphs, ndata=ALL, edata=ALL, *,
node_attrs=None, edge_attrs=None): node_attrs=None, edge_attrs=None):
...@@ -416,6 +416,72 @@ def unbatch(g, node_split=None, edge_split=None): ...@@ -416,6 +416,72 @@ def unbatch(g, node_split=None, edge_split=None):
return gs return gs
def slice_batch(g, gid, store_ids=False):
    """Extract one component graph from a batched graph.

    Parameters
    ----------
    g : DGLGraph
        The batched graph to slice.
    gid : int
        Index of the component graph within the batch.
    store_ids : bool
        When True, the IDs that the extracted nodes and edges have in ``g`` are
        stored in the ``ndata`` and ``edata`` of the result under the names
        ``dgl.NID`` and ``dgl.EID``, respectively.

    Returns
    -------
    DGLGraph
        The requested component, carrying its slice of every node and edge feature.
    """
    def _locate(batch_sizes):
        # Size of component ``gid`` and the offset at which it starts, i.e. the
        # summed sizes of all the components that precede it in the batch.
        size = F.as_scalar(batch_sizes[gid])
        if gid == 0:
            offset = 0
        else:
            offset = F.as_scalar(F.sum(F.slice_axis(batch_sizes, 0, 0, gid), 0))
        return offset, size

    start_nid, num_nodes = [], []
    for ntype in g.ntypes:
        offset, size = _locate(g.batch_num_nodes(ntype))
        start_nid.append(offset)
        num_nodes.append(size)

    start_eid, num_edges = [], []
    for etype in g.canonical_etypes:
        offset, size = _locate(g.batch_num_edges(etype))
        start_eid.append(offset)
        num_edges.append(size)

    # Graph structure
    gidx = slice_gidx(g._graph,
                      utils.toindex(num_nodes), utils.toindex(start_nid),
                      utils.toindex(num_edges), utils.toindex(start_eid))
    retg = DGLHeteroGraph(gidx, g.ntypes, g.etypes)

    # Node features: rows [first, last) of every feature belong to the component
    for ntype, first, count in zip(g.ntypes, start_nid, num_nodes):
        last = first + count
        for key, feat in g.nodes[ntype].data.items():
            retg.nodes[ntype].data[key] = F.slice_axis(feat, 0, first, last)

        if store_ids:
            retg.nodes[ntype].data[NID] = F.arange(first, last, retg.idtype, retg.device)

    # Edge features: same scheme, per canonical edge type
    for etype, first, count in zip(g.canonical_etypes, start_eid, num_edges):
        last = first + count
        for key, feat in g.edges[etype].data.items():
            retg.edges[etype].data[key] = F.slice_axis(feat, 0, first, last)

        if store_ids:
            retg.edges[etype].data[EID] = F.arange(first, last, retg.idtype, retg.device)

    return retg
#### DEPRECATED APIS #### #### DEPRECATED APIS ####
def batch_hetero(*args, **kwargs): def batch_hetero(*args, **kwargs):
......
...@@ -1231,6 +1231,31 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types): ...@@ -1231,6 +1231,31 @@ def disjoint_partition(graph, bnn_all_types, bne_all_types):
return _CAPI_DGLHeteroDisjointPartitionBySizes_v2( return _CAPI_DGLHeteroDisjointPartitionBySizes_v2(
graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor()) graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor())
def slice_gidx(graph, num_nodes, start_nid, num_edges, start_eid):
    """Slice a contiguous chunk out of a graph index.

    Parameters
    ----------
    graph : HeteroGraphIndex
        The batched graph to slice.
    num_nodes : utils.Index
        Number of nodes per node type in the result graph.
    start_nid : utils.Index
        Start node ID per node type in the result graph.
    num_edges : utils.Index
        Number of edges per edge type in the result graph.
    start_eid : utils.Index
        Start edge ID per edge type in the result graph.

    Returns
    -------
    HeteroGraphIndex
        The sliced graph.
    """
    # The C API takes DGL tensors, in the same order as this function's arguments.
    tensors = [index.todgltensor()
               for index in (num_nodes, start_nid, num_edges, start_eid)]
    return _CAPI_DGLHeteroSlice(graph, *tensors)
################################################################# #################################################################
# Data structure used by C APIs # Data structure used by C APIs
################################################################# #################################################################
......
...@@ -110,6 +110,43 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -110,6 +110,43 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
return ret; return ret;
} }
/*!
 * \brief Slice a contiguous chunk out of a COO matrix.
 *
 * Rows, columns and (when present) the data array are sliced by
 * [edge_range[0], edge_range[1]) and shifted by the start of the respective
 * range so that the chunk is indexed from 0.
 *
 * For a chunk without edges no slicing is performed at all: the result has
 * zero-length row/col arrays and a null data array.  This mirrors
 * CSRSliceContiguousChunk and avoids issuing a ranged IndexSelect on the data
 * array with start == end, which may point at the very end of the array.
 */
COOMatrix COOSliceContiguousChunk(
    const COOMatrix &coo,
    const std::vector<uint64_t> &edge_range,
    const std::vector<uint64_t> &src_vertex_range,
    const std::vector<uint64_t> &dst_vertex_range) {
  IdArray result_src = NullArray(coo.row->dtype, coo.row->ctx);
  IdArray result_dst = NullArray(coo.row->dtype, coo.row->ctx);
  IdArray result_data = NullArray();
  if (edge_range[1] != edge_range[0]) {
    // The chunk has edges
    result_src = IndexSelect(coo.row,
                             edge_range[0],
                             edge_range[1]) - src_vertex_range[0];
    result_dst = IndexSelect(coo.col,
                             edge_range[0],
                             edge_range[1]) - dst_vertex_range[0];
    // has data index array
    if (COOHasData(coo)) {
      result_data = IndexSelect(coo.data,
                                edge_range[0],
                                edge_range[1]) - edge_range[0];
    }
  }

  // Sortedness of a contiguous slice is inherited from the input.
  return COOMatrix(
    src_vertex_range[1] - src_vertex_range[0],
    dst_vertex_range[1] - dst_vertex_range[0],
    result_src,
    result_dst,
    result_data,
    coo.row_sorted,
    coo.col_sorted);
}
///////////////////////// CSR Based Operations///////////////////////// ///////////////////////// CSR Based Operations/////////////////////////
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
uint64_t src_offset = 0, dst_offset = 0; uint64_t src_offset = 0, dst_offset = 0;
...@@ -222,5 +259,40 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes( ...@@ -222,5 +259,40 @@ std::vector<CSRMatrix> DisjointPartitionCsrBySizes(
return ret; return ret;
} }
/*!
 * \brief Slice a contiguous chunk out of a CSR matrix.
 *
 * A chunk without edges yields an all-zero indptr of length num_rows + 1,
 * zero-length indices and a null data array.  Otherwise indptr, indices and
 * (when present) data are sliced and shifted so the chunk is indexed from 0.
 */
CSRMatrix CSRSliceContiguousChunk(
    const CSRMatrix &csr,
    const std::vector<uint64_t> &edge_range,
    const std::vector<uint64_t> &src_vertex_range,
    const std::vector<uint64_t> &dst_vertex_range) {
  const int64_t num_rows = src_vertex_range[1] - src_vertex_range[0];
  const int64_t num_cols = dst_vertex_range[1] - dst_vertex_range[0];

  // Defaults describe an empty chunk.
  IdArray sliced_indptr = Full(0, num_rows + 1, csr.indptr->dtype.bits, csr.indptr->ctx);
  IdArray sliced_indices = NullArray(csr.indptr->dtype, csr.indptr->ctx);
  IdArray sliced_data = NullArray();

  const bool has_edges = (edge_range[0] != edge_range[1]);
  if (has_edges) {
    // indptr needs one more entry than there are rows; its values are edge
    // offsets and therefore shift by the first edge of the chunk.
    sliced_indptr = IndexSelect(
        csr.indptr, src_vertex_range[0], src_vertex_range[1] + 1) - edge_range[0];
    sliced_indices = IndexSelect(
        csr.indices, edge_range[0], edge_range[1]) - dst_vertex_range[0];
    if (CSRHasData(csr)) {
      sliced_data = IndexSelect(
          csr.data, edge_range[0], edge_range[1]) - edge_range[0];
    }
  }

  return CSRMatrix(
    num_rows,
    num_cols,
    sliced_indptr,
    sliced_indices,
    sliced_data,
    csr.sorted);
}
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -581,6 +581,18 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") ...@@ -581,6 +581,18 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
*rv = ret_list; *rv = ret_list;
}); });
// C API entry point behind the Python-side slice_gidx: slices a contiguous
// chunk out of a (batched) heterograph.
// Arguments: graph, nodes per type, start node ID per type, edges per type,
// start edge ID per type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSlice")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef batched = args[0];
    const IdArray node_counts = args[1];
    const IdArray node_starts = args[2];
    const IdArray edge_counts = args[3];
    const IdArray edge_starts = args[4];
    // The metagraph is taken from the input graph itself.
    *rv = HeteroGraphRef(SliceHeteroGraph(
        batched->meta_graph(), batched.sptr(),
        node_counts, node_starts, edge_counts, edge_starts));
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
...@@ -269,6 +269,71 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -269,6 +269,71 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
return rst; return rst;
} }
/*!
 * \brief Slice a contiguous subgraph out of a (batched) heterograph.
 *
 * Every relation graph is sliced in one of its allowed sparse formats
 * (COO preferred, then CSR, then CSC) and the slices are assembled into a
 * new heterograph over the same metagraph.
 *
 * The buffers of the four ID arrays are read as 64-bit integers.
 */
HeteroGraphPtr SliceHeteroGraph(
    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray num_nodes_per_type,
    IdArray start_nid_per_type, IdArray num_edges_per_type, IdArray start_eid_per_type) {
  const uint64_t* nid_start = static_cast<uint64_t*>(start_nid_per_type->data);
  const uint64_t* nid_count = static_cast<uint64_t*>(num_nodes_per_type->data);
  const uint64_t* eid_start = static_cast<uint64_t*>(start_eid_per_type->data);
  const uint64_t* eid_count = static_cast<uint64_t*>(num_edges_per_type->data);

  // vertex_range[vtype] = {first vertex ID, one past the last vertex ID}
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  std::vector<std::vector<uint64_t>> vertex_range(num_vertex_types);
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    vertex_range[vtype] = {nid_start[vtype], nid_start[vtype] + nid_count[vtype]};
  }

  // One sliced relation graph per canonical edge type
  const auto num_edge_types = meta_graph->NumEdges();
  std::vector<HeteroGraphPtr> rel_graphs(num_edge_types);
  for (dgl_type_t etype = 0; etype < num_edge_types; ++etype) {
    const auto endpoints = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    // A relation whose endpoints share a type has a single vertex type.
    const int num_vtypes = (src_vtype == dst_vtype) ? 1 : 2;
    const dgl_format_code_t code = batched_graph->GetRelationGraph(etype)->GetAllowedFormats();

    // {first edge ID, one past the last edge ID}; may be empty
    const std::vector<uint64_t> edge_range = {
      eid_start[etype], eid_start[etype] + eid_count[etype]};

    HeteroGraphPtr rgptr = nullptr;
    if (FORMAT_HAS_COO(code)) {
      // prefer COO
      const aten::COOMatrix sliced = aten::COOSliceContiguousChunk(
          batched_graph->GetCOOMatrix(etype),
          edge_range,
          vertex_range[src_vtype],
          vertex_range[dst_vtype]);
      rgptr = UnitGraph::CreateFromCOO(num_vtypes, sliced, code);
    } else if (FORMAT_HAS_CSR(code)) {
      const aten::CSRMatrix sliced = aten::CSRSliceContiguousChunk(
          batched_graph->GetCSRMatrix(etype),
          edge_range,
          vertex_range[src_vtype],
          vertex_range[dst_vtype]);
      rgptr = UnitGraph::CreateFromCSR(num_vtypes, sliced, code);
    } else if (FORMAT_HAS_CSC(code)) {
      // CSR and CSC have the same storage format, i.e. CSRMatrix; the rows of
      // a CSC are the destination vertices, hence the swapped vertex ranges.
      const aten::CSRMatrix sliced = aten::CSRSliceContiguousChunk(
          batched_graph->GetCSCMatrix(etype),
          edge_range,
          vertex_range[dst_vtype],
          vertex_range[src_vtype]);
      rgptr = UnitGraph::CreateFromCSC(num_vtypes, sliced, code);
    }
    rel_graphs[etype] = rgptr;
  }

  return CreateHeteroGraph(meta_graph, rel_graphs, num_nodes_per_type.ToVector<int64_t>());
}
template <class IdType> template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) { GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
......
...@@ -333,6 +333,49 @@ def test_unbatch2(idtype): ...@@ -333,6 +333,49 @@ def test_unbatch2(idtype):
check_graph_equal(g3, gg3) check_graph_equal(g3, gg3)
@parametrize_dtype
def test_slice_batch(idtype):
    """``dgl.slice_batch`` must recover every component of a batched heterograph,
    structure sizes and features alike, for each sparse format."""
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([], []),
        ('user', 'follows', 'game'): ([0, 0], [1, 4])
    }, idtype=idtype, device=F.ctx())
    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1], [0, 0]),
        ('user', 'follows', 'game'): ([0, 1], [1, 4])
    }, num_nodes_dict={'user': 4, 'game': 6}, idtype=idtype, device=F.ctx())
    g3 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0], [2]),
        ('user', 'plays', 'game'): ([1, 2], [3, 4]),
        ('user', 'follows', 'game'): ([], [])
    }, idtype=idtype, device=F.ctx())
    g_list = [g1, g2, g3]
    bg = dgl.batch(g_list)
    bg.nodes['user'].data['h1'] = F.randn((bg.num_nodes('user'), 2))
    bg.nodes['user'].data['h2'] = F.randn((bg.num_nodes('user'), 5))
    bg.edges[('user', 'follows', 'user')].data['h1'] = F.randn((
        bg.num_edges(('user', 'follows', 'user')), 2))

    # The features live on the batched graph only; the graphs in ``g_list`` carry
    # none.  Use the unbatched components as the feature reference, otherwise the
    # feature checks below would loop over nothing.
    ref_list = dgl.unbatch(bg)

    for fmat in ['coo', 'csr', 'csc']:
        bg = bg.formats(fmat)
        for i, g_i in enumerate(g_list):
            ref_i = ref_list[i]
            g_slice = dgl.slice_batch(bg, i)
            assert g_i.ntypes == g_slice.ntypes
            assert g_i.canonical_etypes == g_slice.canonical_etypes
            assert g_i.idtype == g_slice.idtype
            assert g_i.device == g_slice.device
            for nty in g_i.ntypes:
                assert g_i.num_nodes(nty) == g_slice.num_nodes(nty)
                for feat in ref_i.nodes[nty].data:
                    assert F.allclose(ref_i.nodes[nty].data[feat],
                                      g_slice.nodes[nty].data[feat])
            for ety in g_i.canonical_etypes:
                assert g_i.num_edges(ety) == g_slice.num_edges(ety)
                for feat in ref_i.edges[ety].data:
                    assert F.allclose(ref_i.edges[ety].data[feat],
                                      g_slice.edges[ety].data[feat])
@parametrize_dtype @parametrize_dtype
def test_batch_keeps_empty_data(idtype): def test_batch_keeps_empty_data(idtype):
g1 = dgl.heterograph({("a", "to", "a"): ([], [])} g1 = dgl.heterograph({("a", "to", "a"): ([], [])}
......
...@@ -781,6 +781,175 @@ TEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) { ...@@ -781,6 +781,175 @@ TEST(DisjointUnionTest, TestDisjointUnionPartitionCsr) {
#endif #endif
} }
/*!
 * \brief Check COOSliceContiguousChunk on the device given by \a ctx.
 *
 * All input arrays are created on \a ctx, so the GPU instantiations exercise
 * the requested device rather than a fixed context.
 */
template <typename IdType>
void _TestSliceContiguousChunkCoo(DLContext ctx) {
  /*
   * A = [[1, 0, 0, 0],
   *      [0, 0, 1, 0],
   *      [0, 0, 0, 0]]
   *
   * B = [[1, 0, 0],
   *      [0, 0, 1]]
   *
   * C = [[0]]
   *
   */
  const auto nbits = sizeof(IdType) * 8;

  IdArray a_row = aten::VecToIdArray(std::vector<IdType>({0, 1}), nbits, ctx);
  IdArray a_col = aten::VecToIdArray(std::vector<IdType>({0, 2}), nbits, ctx);
  const aten::COOMatrix coo_a = aten::COOMatrix(
    3,
    4,
    a_row,
    a_col,
    aten::NullArray(),
    true,
    false);

  // B: the chunk holding both edges of A
  IdArray b_row = aten::VecToIdArray(std::vector<IdType>({0, 1}), nbits, ctx);
  IdArray b_col = aten::VecToIdArray(std::vector<IdType>({0, 2}), nbits, ctx);
  const aten::COOMatrix coo_b_raw = aten::COOMatrix(
    2,
    3,
    b_row,
    b_col,
    aten::NullArray(),
    true,
    false);

  const std::vector<uint64_t> edge_range_b({0, 2});
  const std::vector<uint64_t> src_vertex_range_b({0, 2});
  const std::vector<uint64_t> dst_vertex_range_b({0, 3});
  const aten::COOMatrix coo_b = aten::COOSliceContiguousChunk(
    coo_a,
    edge_range_b,
    src_vertex_range_b,
    dst_vertex_range_b);
  ASSERT_EQ(coo_b_raw.num_rows, coo_b.num_rows);
  ASSERT_EQ(coo_b_raw.num_cols, coo_b.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_b_raw.row, coo_b.row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_b_raw.col, coo_b.col));
  ASSERT_TRUE(coo_b.row_sorted);
  ASSERT_FALSE(coo_b.col_sorted);

  // C: a chunk without edges
  IdArray c_row = aten::VecToIdArray(std::vector<IdType>({}), nbits, ctx);
  IdArray c_col = aten::VecToIdArray(std::vector<IdType>({}), nbits, ctx);
  const aten::COOMatrix coo_c_raw = aten::COOMatrix(
    1,
    1,
    c_row,
    c_col,
    aten::NullArray(),
    true,
    false);

  const std::vector<uint64_t> edge_range_c({2, 2});
  const std::vector<uint64_t> src_vertex_range_c({2, 3});
  const std::vector<uint64_t> dst_vertex_range_c({3, 4});
  const aten::COOMatrix coo_c = aten::COOSliceContiguousChunk(
    coo_a,
    edge_range_c,
    src_vertex_range_c,
    dst_vertex_range_c);
  ASSERT_EQ(coo_c_raw.num_rows, coo_c.num_rows);
  ASSERT_EQ(coo_c_raw.num_cols, coo_c.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(coo_c.row, c_row));
  ASSERT_TRUE(ArrayEQ<IdType>(coo_c.col, c_col));
  ASSERT_TRUE(coo_c.row_sorted);
  ASSERT_FALSE(coo_c.col_sorted);
}
// Runs the COO slicing checks with 32-bit and 64-bit IDs; the GPU variants are
// compiled only when DGL_USE_CUDA is defined.
TEST(SliceContiguousChunk, TestSliceContiguousChunkCoo) {
_TestSliceContiguousChunkCoo<int32_t>(CPU);
_TestSliceContiguousChunkCoo<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestSliceContiguousChunkCoo<int32_t>(GPU);
_TestSliceContiguousChunkCoo<int64_t>(GPU);
#endif
}
/*!
 * \brief Check CSRSliceContiguousChunk on the device given by \a ctx.
 *
 * All input arrays are created on \a ctx, so the GPU instantiations exercise
 * the requested device rather than a fixed context.
 */
template <typename IdType>
void _TestSliceContiguousChunkCsr(DLContext ctx) {
  /*
   * A = [[1, 0, 0, 0],
   *      [0, 0, 1, 0],
   *      [0, 0, 0, 0]]
   *
   * B = [[1, 0, 0],
   *      [0, 0, 1]]
   *
   * C = [[0]]
   *
   */
  const auto nbits = sizeof(IdType) * 8;

  IdArray a_indptr = aten::VecToIdArray(std::vector<IdType>({0, 1, 2, 2}), nbits, ctx);
  IdArray a_indices = aten::VecToIdArray(std::vector<IdType>({0, 2}), nbits, ctx);
  const aten::CSRMatrix csr_a = aten::CSRMatrix(
    3,
    4,
    a_indptr,
    a_indices,
    aten::NullArray(),
    false);

  // B: the chunk holding both edges of A
  IdArray b_indptr = aten::VecToIdArray(std::vector<IdType>({0, 1, 2}), nbits, ctx);
  IdArray b_indices = aten::VecToIdArray(std::vector<IdType>({0, 2}), nbits, ctx);
  const aten::CSRMatrix csr_b_raw = aten::CSRMatrix(
    2,
    3,
    b_indptr,
    b_indices,
    aten::NullArray(),
    false);

  const std::vector<uint64_t> edge_range_b({0, 2});
  const std::vector<uint64_t> src_vertex_range_b({0, 2});
  const std::vector<uint64_t> dst_vertex_range_b({0, 3});
  const aten::CSRMatrix csr_b = aten::CSRSliceContiguousChunk(
    csr_a,
    edge_range_b,
    src_vertex_range_b,
    dst_vertex_range_b);
  ASSERT_EQ(csr_b.num_rows, csr_b_raw.num_rows);
  ASSERT_EQ(csr_b.num_cols, csr_b_raw.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indptr, csr_b_raw.indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_b.indices, csr_b_raw.indices));
  ASSERT_FALSE(csr_b.sorted);

  // C: a chunk without edges; its indptr is all zeros
  const std::vector<uint64_t> edge_range_c({2, 2});
  const std::vector<uint64_t> src_vertex_range_c({2, 3});
  const std::vector<uint64_t> dst_vertex_range_c({3, 4});
  const aten::CSRMatrix csr_c = aten::CSRSliceContiguousChunk(
    csr_a,
    edge_range_c,
    src_vertex_range_c,
    dst_vertex_range_c);

  int64_t indptr_len = src_vertex_range_c[1] - src_vertex_range_c[0] + 1;
  IdArray c_indptr = aten::Full(0, indptr_len, nbits, ctx);
  IdArray c_indices = aten::VecToIdArray(std::vector<IdType>({}), nbits, ctx);
  const aten::CSRMatrix csr_c_raw = aten::CSRMatrix(
    1,
    1,
    c_indptr,
    c_indices,
    aten::NullArray(),
    false);

  ASSERT_EQ(csr_c.num_rows, csr_c_raw.num_rows);
  ASSERT_EQ(csr_c.num_cols, csr_c_raw.num_cols);
  ASSERT_TRUE(ArrayEQ<IdType>(csr_c.indptr, c_indptr));
  ASSERT_TRUE(ArrayEQ<IdType>(csr_c.indices, c_indices));
  ASSERT_FALSE(csr_c.sorted);
}
// Runs the CSR slicing checks with 32-bit and 64-bit IDs; the GPU variants are
// compiled only when DGL_USE_CUDA is defined.
TEST(SliceContiguousChunk, TestSliceContiguousChunkCsr) {
_TestSliceContiguousChunkCsr<int32_t>(CPU);
_TestSliceContiguousChunkCsr<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestSliceContiguousChunkCsr<int32_t>(GPU);
_TestSliceContiguousChunkCsr<int64_t>(GPU);
#endif
}
template <typename IdType> template <typename IdType>
void _TestMatrixUnionCsr(DLContext ctx) { void _TestMatrixUnionCsr(DLContext ctx) {
/* /*
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment