Unverified Commit ea420c0a authored by Da Zheng, committed by GitHub

[Distributed] Reduce memory consumption in graph partitioning (#1823)



* save mem.

* save mem.

* reduce mem

* fix test

* fix lint

* fix test

* fix.

* fix.

* fix.

* fix.

* fix lint.

* fix backend operator.

* fix tensorflow operators.

* fix.

* revert change in mxnet operator.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-1.us-west-2.compute.internal>
parent 7645c660
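
The theme of the change: the per-node and per-edge bookkeeping that partitioning keeps around is slimmed down. The `inner_node`/`inner_edge` membership indicators move from int64 to int8, per-partition node/edge id lists are no longer held for the whole run, the halo subgraphs stop carrying a redundant `inner_edges` array across the C API, and the parent graph is freed as soon as the partitions exist. A rough numpy sketch of the dtype saving alone (not DGL code; the edge count is made up):

```python
import numpy as np

num_edges = 10_000_000  # made-up edge count

mask_int64 = np.ones(num_edges, dtype=np.int64)  # old: 8 bytes per edge
mask_int8 = np.ones(num_edges, dtype=np.int8)    # new: 1 byte per edge

print(mask_int64.nbytes // 2**20, 'MiB')  # 76 MiB
print(mask_int8.nbytes // 2**20, 'MiB')   # 9 MiB
```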
@@ -428,9 +428,6 @@ struct NegSubgraph: public Subgraph {
 struct HaloSubgraph: public Subgraph {
   /*! \brief Indicate if a node belongs to the partition. */
   IdArray inner_nodes;
-  /*! \brief Indicate if an edge belongs to the partition. */
-  IdArray inner_edges;
 };
 
 // Define SubgraphRef
...
@@ -356,7 +356,7 @@ def nonzero_1d(input):
     # TODO: fallback to numpy is unfortunate
     tmp = input.asnumpy()
     tmp = np.nonzero(tmp)[0]
-    return nd.array(tmp, ctx=input.context, dtype=input.dtype)
+    return nd.array(tmp, ctx=input.context, dtype=tmp.dtype)
 
 def sort_1d(input):
     # TODO: this isn't an ideal implementation.
...
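
This small MXNet fix matters once the masks above are int8: `np.nonzero` returns int64 positions, and building the result with the *input's* dtype (the old behavior) can silently overflow. In plain numpy:

```python
import numpy as np

mask = np.zeros(1000, dtype=np.int8)
mask[999] = 1

idx = np.nonzero(mask)[0]      # positions come back as int64
print(idx, idx.dtype)          # [999] int64

# The old code cast positions to the input's dtype, which can overflow:
print(idx.astype(mask.dtype))  # [-25]  (999 does not fit in int8)
```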
@@ -63,7 +63,8 @@ def initialize_context():
     tf.zeros(1)
 
 def as_scalar(data):
-    return data.numpy().asscalar()
+    data = data.numpy()
+    return data if np.isscalar(data) else data.asscalar()
 
 def get_preferred_sparse_format():
@@ -384,7 +385,7 @@ def full_1d(length, fill_value, dtype, ctx):
 def nonzero_1d(input):
-    nonzero_bool = (input != False)
+    nonzero_bool = tf.cast(input, tf.bool)
     return tf.reshape(tf.where(nonzero_bool), (-1, ))
...
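
On the TensorFlow side, the explicit cast makes `nonzero_1d` treat any nonzero value of any numeric dtype (including the new int8 masks) as set, rather than relying on `!= False` comparison semantics. A quick check of the new expression, assuming TensorFlow 2.x is installed:

```python
import tensorflow as tf

x = tf.constant([3, 0, -1, 0, 1], dtype=tf.int8)
nonzero_bool = tf.cast(x, tf.bool)               # any nonzero value -> True
idx = tf.reshape(tf.where(nonzero_bool), (-1,))  # tf.where returns int64 coords
print(idx.numpy())                               # [0 2 4]
```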
@@ -101,10 +101,10 @@ class GraphPartitionBook:
         self._part_id = int(part_id)
         self._num_partitions = int(num_parts)
         self._nid2partid = F.tensor(node_map)
-        assert F.dtype(self._nid2partid) in (F.int32, F.int64), \
+        assert F.dtype(self._nid2partid) == F.int64, \
                 'the node map must be stored in an integer array'
         self._eid2partid = F.tensor(edge_map)
-        assert F.dtype(self._eid2partid) in (F.int32, F.int64), \
+        assert F.dtype(self._eid2partid) == F.int64, \
                 'the edge map must be stored in an integer array'
         # Get meta data of the partition book.
         self._partition_meta_data = []
...
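
The partition book now insists the node/edge maps arrive as int64 instead of also accepting int32, presumably so a single well-defined layout backs every lookup; global ids on very large graphs would not fit in int32 anyway. A toy illustration of how such a map is consumed (plain numpy, not the DGL API):

```python
import numpy as np

# nid2partid[i] is the partition owning node i (6 nodes, 3 partitions here).
nid2partid = np.array([0, 0, 1, 1, 2, 2], dtype=np.int64)

print(int(nid2partid[4]))              # node 4 lives on partition 2
print(np.nonzero(nid2partid == 1)[0])  # nodes owned by partition 1: [2 3]
```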
@@ -254,49 +254,36 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
         Indicate whether to balance the edges.
     '''
     if num_parts == 1:
-        client_parts = {0: g}
+        parts = {0: g}
         node_parts = F.zeros((g.number_of_nodes(),), F.int64, F.cpu())
         g.ndata[NID] = F.arange(0, g.number_of_nodes())
         g.edata[EID] = F.arange(0, g.number_of_edges())
-        g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int64, F.cpu())
-        g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int64, F.cpu())
+        g.ndata['inner_node'] = F.ones((g.number_of_nodes(),), F.int8, F.cpu())
+        g.edata['inner_edge'] = F.ones((g.number_of_edges(),), F.int8, F.cpu())
         if reshuffle:
             g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
             g.edata['orig_id'] = F.arange(0, g.number_of_edges())
     elif part_method == 'metis':
         node_parts = metis_partition_assignment(g, num_parts, balance_ntypes=balance_ntypes,
                                                 balance_edges=balance_edges)
-        client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
+        parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
     elif part_method == 'random':
         node_parts = random_choice(num_parts, g.number_of_nodes())
-        client_parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
+        parts = partition_graph_with_halo(g, node_parts, num_hops, reshuffle=reshuffle)
     else:
         raise Exception('Unknown partitioning method: ' + part_method)
 
     # Let's calculate edge assignment.
-    # TODO(zhengda) we should replace int64 with int16. int16 should be sufficient.
-    start = time.time()
     if not reshuffle:
+        start = time.time()
+        # We only optimize for reshuffled case. So it's fine to use int64 here.
         edge_parts = np.zeros((g.number_of_edges(),), dtype=np.int64) - 1
-    num_edges = 0
-    num_nodes = 0
-    lnodes_list = []    # The node ids of each partition
-    ledges_list = []    # The edge Ids of each partition
-    for part_id in range(num_parts):
-        part = client_parts[part_id]
-        # To get the edges in the input graph, we should use original node Ids.
-        data_name = 'orig_id' if reshuffle else NID
-        local_nodes = F.boolean_mask(part.ndata[data_name], part.ndata['inner_node'])
-        local_edges = g.in_edges(local_nodes, form='eid')
-        if not reshuffle:
+        for part_id in parts:
+            part = parts[part_id]
+            # To get the edges in the input graph, we should use original node Ids.
+            local_edges = F.boolean_mask(part.edata[EID], part.edata['inner_edge'])
             edge_parts[F.asnumpy(local_edges)] = part_id
-        num_edges += len(local_edges)
-        num_nodes += len(local_nodes)
-        lnodes_list.append(local_nodes)
-        ledges_list.append(local_edges)
-    assert num_edges == g.number_of_edges()
-    assert num_nodes == g.number_of_nodes()
-    print('Calculate edge assignment: {:.3f} seconds'.format(time.time() - start))
+        print('Calculate edge assignment: {:.3f} seconds'.format(time.time() - start))
 
     os.makedirs(out_path, mode=0o775, exist_ok=True)
     tot_num_inner_edges = 0
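
Edge assignment now runs only in the non-reshuffled case and reads each partition's precomputed `inner_edge` mask, instead of calling `g.in_edges()` per partition and keeping `lnodes_list`/`ledges_list` alive for the rest of the function. The logic, roughly, in plain numpy (the two partitions below are made up):

```python
import numpy as np

num_edges = 10
edge_parts = np.full(num_edges, -1, dtype=np.int64)

# Hypothetical partitions: (original edge ids, inner_edge mask).
# Halo edges carry mask 0 and are skipped, so no edge is assigned twice.
parts = {
    0: (np.array([0, 1, 2, 7]), np.array([1, 1, 1, 0], dtype=np.int8)),
    1: (np.array([3, 4, 5, 6, 7, 8, 9, 2]),
        np.array([1, 1, 1, 1, 1, 1, 1, 0], dtype=np.int8)),
}
for part_id, (eids, inner_edge) in parts.items():
    local_edges = eids[inner_edge.astype(bool)]  # boolean_mask equivalent
    edge_parts[local_edges] = part_id

print(edge_parts)  # [0 0 0 1 1 1 1 1 1 1] -- each edge owned by one partition
```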
@@ -314,8 +301,14 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
     # With reshuffling, we can ensure that all nodes and edges are reshuffled
     # and are in contiguous Id space.
     if num_parts > 1:
-        node_map_val = np.cumsum([len(lnodes) for lnodes in lnodes_list]).tolist()
-        edge_map_val = np.cumsum([len(ledges) for ledges in ledges_list]).tolist()
+        node_map_val = [F.as_scalar(F.sum(F.astype(parts[i].ndata['inner_node'], F.int64),
+                                          0)) for i in parts]
+        node_map_val = np.cumsum(node_map_val).tolist()
+        assert node_map_val[-1] == g.number_of_nodes()
+        edge_map_val = [F.as_scalar(F.sum(F.astype(parts[i].edata['inner_edge'], F.int64),
+                                          0)) for i in parts]
+        edge_map_val = np.cumsum(edge_map_val).tolist()
+        assert edge_map_val[-1] == g.number_of_edges()
     else:
         node_map_val = [g.number_of_nodes()]
         edge_map_val = [g.number_of_edges()]
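
The cumulative sums replace the dropped id lists: with reshuffling, partition `i` owns the contiguous id range `[node_map_val[i-1], node_map_val[i])`, so finding an id's owner is a binary search over the boundaries. A sketch with a hypothetical `find_partition` helper (plain numpy, not a DGL API):

```python
import numpy as np

inner_node_counts = [4, 3, 5]  # per-partition inner-node counts
node_map_val = np.cumsum(inner_node_counts).tolist()  # [4, 7, 12]

def find_partition(nid, node_map_val):
    # The first boundary strictly greater than nid marks the owning partition.
    return int(np.searchsorted(node_map_val, nid, side='right'))

print(find_partition(0, node_map_val))   # 0
print(find_partition(6, node_map_val))   # 1
print(find_partition(11, node_map_val))  # 2
```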
@@ -330,14 +323,17 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                  'node_map': node_map_val,
                  'edge_map': edge_map_val}
     for part_id in range(num_parts):
-        part = client_parts[part_id]
+        part = parts[part_id]
         # Get the node/edge features of each partition.
         node_feats = {}
         edge_feats = {}
         if num_parts > 1:
-            local_nodes = lnodes_list[part_id]
-            local_edges = ledges_list[part_id]
+            # To get the edges in the input graph, we should use original node Ids.
+            ndata_name = 'orig_id' if reshuffle else NID
+            edata_name = 'orig_id' if reshuffle else EID
+            local_nodes = F.boolean_mask(part.ndata[ndata_name], part.ndata['inner_node'])
+            local_edges = F.boolean_mask(part.edata[edata_name], part.edata['inner_edge'])
         print('part {} has {} nodes and {} edges.'.format(
             part_id, part.number_of_nodes(), part.number_of_edges()))
         print('{} nodes and {} edges are inside the partition'.format(
...
@@ -566,8 +566,7 @@ class GraphIndex(ObjectBase):
         v_array = v.todgltensor()
         subg = _CAPI_DGLGetSubgraphWithHalo(self, v_array, num_hops)
         inner_nodes = _CAPI_GetHaloSubgraphInnerNodes(subg)
-        inner_edges = _CAPI_GetHaloSubgraphInnerEdges(subg)
-        return subg, inner_nodes, inner_edges
+        return subg, inner_nodes
 
     def node_subgraphs(self, vs_arr):
         """Return the induced node subgraphs.
@@ -1297,7 +1296,4 @@ def create_graph_index(graph_data, readonly):
 def _get_halo_subgraph_inner_node(halo_subg):
     return _CAPI_GetHaloSubgraphInnerNodes(halo_subg)
 
-def _get_halo_subgraph_inner_edge(halo_subg):
-    return _CAPI_GetHaloSubgraphInnerEdges(halo_subg)
-
 _init_api("dgl.graph_index")
@@ -962,27 +962,38 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
         # that all edges in a partition are in the contiguous Id space.
         orig_eids = _CAPI_DGLReassignEdges(g._graph, True)
         orig_eids = utils.toindex(orig_eids)
-        g.edata['orig_id'] = orig_eids.tousertensor()
+        orig_eids = orig_eids.tousertensor()
+        orig_nids = g.ndata['orig_id']
         print('Reshuffle nodes and edges: {:.3f} seconds'.format(time.time() - start))
 
     start = time.time()
     subgs = _CAPI_DGLPartitionWithHalo(g._graph, node_part.todgltensor(), extra_cached_hops)
+    # g is no longer needed. Free memory.
+    g = None
     print('Split the graph: {:.3f} seconds'.format(time.time() - start))
     subg_dict = {}
     node_part = node_part.tousertensor()
     start = time.time()
+
+    # This creates a subgraph from subgraphs returned from the CAPI above.
+    def create_subgraph(subg, induced_nodes, induced_edges):
+        subg1 = DGLGraph(graph_data=subg.graph, readonly=True)
+        subg1.ndata[NID] = induced_nodes.tousertensor()
+        subg1.edata[EID] = induced_edges.tousertensor()
+        return subg1
+
     for i, subg in enumerate(subgs):
         inner_node = _get_halo_subgraph_inner_node(subg)
-        subg = g._create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
+        subg = create_subgraph(subg, subg.induced_nodes, subg.induced_edges)
         inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
         subg.ndata['inner_node'] = inner_node
-        subg.ndata['part_id'] = F.gather_row(node_part, subg.parent_nid)
+        subg.ndata['part_id'] = F.gather_row(node_part, subg.ndata[NID])
         if reshuffle:
-            subg.ndata['orig_id'] = F.gather_row(g.ndata['orig_id'], subg.ndata[NID])
-            subg.edata['orig_id'] = F.gather_row(g.edata['orig_id'], subg.edata[EID])
+            subg.ndata['orig_id'] = F.gather_row(orig_nids, subg.ndata[NID])
+            subg.edata['orig_id'] = F.gather_row(orig_eids, subg.edata[EID])
         if extra_cached_hops >= 1:
-            inner_edge = F.zeros((subg.number_of_edges(),), F.int64, F.cpu())
+            inner_edge = F.zeros((subg.number_of_edges(),), F.int8, F.cpu())
             inner_nids = F.nonzero_1d(subg.ndata['inner_node'])
             # TODO(zhengda) we need to fix utils.toindex() to avoid the dtype cast below.
             inner_nids = F.astype(inner_nids, F.int64)
@@ -990,7 +1001,7 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
             inner_edge = F.scatter_row(inner_edge, inner_eids,
                                        F.ones((len(inner_eids),), F.dtype(inner_edge), F.cpu()))
         else:
-            inner_edge = F.ones((subg.number_of_edges(),), F.int64, F.cpu())
+            inner_edge = F.ones((subg.number_of_edges(),), F.int8, F.cpu())
         subg.edata['inner_edge'] = inner_edge
         subg_dict[i] = subg
     print('Construct subgraphs: {:.3f} seconds'.format(time.time() - start))
...
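
Two savings in this hunk: the last reference to the parent graph is dropped (`g = None`) once the C API has produced the halo subgraphs, and the `inner_edge` indicator is now built as int8 with ones scattered only at inner-edge positions. The scatter in plain numpy (edge ids are made up):

```python
import numpy as np

num_edges = 12
inner_eids = np.array([0, 1, 2, 5, 8], dtype=np.int64)  # hypothetical inner edges

inner_edge = np.zeros(num_edges, dtype=np.int8)  # was int64 before this change
inner_edge[inner_eids] = 1                       # F.scatter_row equivalent
print(inner_edge)         # [1 1 1 0 0 1 0 0 1 0 0 0]
print(inner_edge.nbytes)  # 12 bytes instead of 96
```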
@@ -416,7 +416,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
   auto orig_nodes = all_nodes;
   std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;
-  std::vector<int> inner_edges;
 
   // When we deal with in-edges, we need to do two things:
   // * find the edges inside the partition and the edges between partitions.
@@ -436,7 +435,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
         edge_src.push_back(src_data[i]);
         edge_dst.push_back(dst_data[i]);
         edge_eid.push_back(eid_data[i]);
-        inner_edges.push_back(it1 != orig_nodes.end());
       }
       // We need to expand only if the node hasn't been seen before.
       auto it = all_nodes.find(src_data[i]);
@@ -463,7 +461,6 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
       edge_src.push_back(src_data[i]);
       edge_dst.push_back(dst_data[i]);
       edge_eid.push_back(eid_data[i]);
-      inner_edges.push_back(false);
       // If we haven't seen this node.
       auto it = all_nodes.find(src_data[i]);
       if (it == all_nodes.end()) {
@@ -502,8 +499,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
   halo_subg.graph = subg;
   halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);
   halo_subg.induced_edges = aten::VecToIdArray(edge_eid);
+  // TODO(zhengda) we need to switch to 8 bytes afterwards.
   halo_subg.inner_nodes = aten::VecToIdArray<int>(inner_nodes, 32);
-  halo_subg.inner_edges = aten::VecToIdArray<int>(inner_edges, 32);
   return halo_subg;
 }
@@ -603,14 +600,6 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes")
     *rv = gptr->inner_nodes;
   });
 
-DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerEdges")
-.set_body([] (DGLArgs args, DGLRetValue* rv) {
-  SubgraphRef g = args[0];
-  auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
-  CHECK(gptr) << "The input graph has to be immutable graph";
-  *rv = gptr->inner_edges;
-});
-
 DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
   List<GraphRef> graphs = args[0];
...
@@ -466,13 +466,13 @@ def test_partition_with_halo():
         for i in range(nf.num_layers):
             layer_nids1 = F.asnumpy(nf.layer_parent_nid(i))
             layer_nids2 = lnf.layer_parent_nid(i)
-            layer_nids2 = F.asnumpy(F.gather_row(subg.parent_nid, layer_nids2))
+            layer_nids2 = F.asnumpy(F.gather_row(subg.ndata[dgl.NID], layer_nids2))
             assert np.all(np.sort(layer_nids1) == np.sort(layer_nids2))
         for i in range(nf.num_blocks):
             block_eids1 = F.asnumpy(nf.block_parent_eid(i))
             block_eids2 = lnf.block_parent_eid(i)
-            block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
+            block_eids2 = F.asnumpy(F.gather_row(subg.edata[dgl.EID], block_eids2))
             assert np.all(np.sort(block_eids1) == np.sort(block_eids2))
     subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
...
@@ -98,7 +98,7 @@ def create_large_graph_index(num_nodes):
 def test_node_subgraph_with_halo():
     gi = create_large_graph_index(1000)
     nodes = np.random.choice(gi.number_of_nodes(), 100, replace=False)
-    halo_subg, inner_node, inner_edge = gi.node_halo_subgraph(toindex(nodes), 2)
+    halo_subg, inner_node = gi.node_halo_subgraph(toindex(nodes), 2)
 
     # Check if edges in the subgraph are in the original graph.
     for s, d, e in zip(*halo_subg.graph.edges()):
@@ -111,12 +111,6 @@ def test_node_subgraph_with_halo():
     inner_node_ids = halo_subg.induced_nodes.tonumpy()[inner_node_ids]
     assert np.all(np.sort(inner_node_ids) == np.sort(nodes))
 
-    # Check if the inner edge labels are correct.
-    inner_edge = inner_edge.asnumpy()
-    inner_edge_ids = halo_subg.induced_edges.tonumpy()[inner_edge > 0]
-    subg = gi.node_subgraph(toindex(nodes))
-    assert np.all(np.sort(subg.induced_edges.tonumpy()) == np.sort(inner_edge_ids))
-
 if __name__ == '__main__':
     test_node_subgraph()
     test_node_subgraph_with_halo()
...