Unverified Commit 015acfd2 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Small fixes (#1859)

* fix node/edge_split.

* fix partition.

* support heterograph interface.

* fix test.

* fix

* fix docstring.
parent ff8f7082
...@@ -372,6 +372,42 @@ class DistGraph: ...@@ -372,6 +372,42 @@ class DistGraph:
""" """
return self._edata return self._edata
@property
def ntypes(self):
"""Return the list of node types of this graph.
Returns
-------
list of str
Examples
--------
>>> g = DistGraph("ip_config.txt", "test")
>>> g.ntypes
['_U']
"""
# Currently, we only support a graph with one node type.
return ['_U']
@property
def etypes(self):
"""Return the list of edge types of this graph.
Returns
-------
list of str
Examples
--------
>>> g = DistGraph("ip_config.txt", "test")
>>> g.etypes
['_E']
"""
# Currently, we only support a graph with one edge type.
return ['_E']
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes""" """Return the number of nodes"""
return self._num_nodes return self._num_nodes
...@@ -544,7 +580,7 @@ def _split_even(partition_book, rank, elements): ...@@ -544,7 +580,7 @@ def _split_even(partition_book, rank, elements):
return eles[offsets[rank-1]:offsets[rank]] return eles[offsets[rank-1]:offsets[rank]]
def node_split(nodes, partition_book, rank=None, force_even=False): def node_split(nodes, partition_book=None, rank=None, force_even=True):
''' Split nodes and return a subset for the local rank. ''' Split nodes and return a subset for the local rank.
This function splits the input nodes based on the partition book and This function splits the input nodes based on the partition book and
...@@ -580,6 +616,10 @@ def node_split(nodes, partition_book, rank=None, force_even=False): ...@@ -580,6 +616,10 @@ def node_split(nodes, partition_book, rank=None, force_even=False):
The vector of node Ids that belong to the rank. The vector of node Ids that belong to the rank.
''' '''
num_nodes = 0 num_nodes = 0
if not isinstance(nodes, DistTensor):
assert partition_book is not None, 'Regular tensor requires a partition book.'
elif partition_book is None:
partition_book = nodes.part_policy.partition_book
for part in partition_book.metadata(): for part in partition_book.metadata():
num_nodes += part['num_nodes'] num_nodes += part['num_nodes']
assert len(nodes) == num_nodes, \ assert len(nodes) == num_nodes, \
...@@ -591,7 +631,7 @@ def node_split(nodes, partition_book, rank=None, force_even=False): ...@@ -591,7 +631,7 @@ def node_split(nodes, partition_book, rank=None, force_even=False):
local_nids = partition_book.partid2nids(partition_book.partid) local_nids = partition_book.partid2nids(partition_book.partid)
return _split_local(partition_book, rank, nodes, local_nids) return _split_local(partition_book, rank, nodes, local_nids)
def edge_split(edges, partition_book, rank=None, force_even=False): def edge_split(edges, partition_book=None, rank=None, force_even=True):
''' Split edges and return a subset for the local rank. ''' Split edges and return a subset for the local rank.
This function splits the input edges based on the partition book and This function splits the input edges based on the partition book and
...@@ -627,6 +667,10 @@ def edge_split(edges, partition_book, rank=None, force_even=False): ...@@ -627,6 +667,10 @@ def edge_split(edges, partition_book, rank=None, force_even=False):
The vector of edge Ids that belong to the rank. The vector of edge Ids that belong to the rank.
''' '''
num_edges = 0 num_edges = 0
if not isinstance(edges, DistTensor):
assert partition_book is not None, 'Regular tensor requires a partition book.'
elif partition_book is None:
partition_book = edges.part_policy.partition_book
for part in partition_book.metadata(): for part in partition_book.metadata():
num_edges += part['num_edges'] num_edges += part['num_edges']
assert len(edges) == num_edges, \ assert len(edges) == num_edges, \
......
...@@ -218,8 +218,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -218,8 +218,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
---------- ----------
g : DistGraph g : DistGraph
The distributed graph. The distributed graph.
nodes : tensor nodes : tensor or dict
Node ids to sample neighbors from. Node ids to sample neighbors from. If it's a dict, it should contain only
one key-value pair to make this API consistent with dgl.sampling.sample_neighbors.
fanout : int fanout : int
The number of sampled neighbors for each node. The number of sampled neighbors for each node.
edge_dir : str, optional edge_dir : str, optional
...@@ -238,6 +239,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -238,6 +239,9 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
``nodes``. The sampled subgraph has the same metagraph as the original ``nodes``. The sampled subgraph has the same metagraph as the original
one. one.
""" """
if isinstance(nodes, dict):
assert len(nodes) == 1, 'The distributed sampler only supports one node type for now.'
nodes = list(nodes.values())[0]
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return SamplingRequest(node_ids, fanout, edge_dir=edge_dir, return SamplingRequest(node_ids, fanout, edge_dir=edge_dir,
prob=prob, replace=replace) prob=prob, replace=replace)
...@@ -267,6 +271,9 @@ def in_subgraph(g, nodes): ...@@ -267,6 +271,9 @@ def in_subgraph(g, nodes):
DGLHeteroGraph DGLHeteroGraph
The subgraph. The subgraph.
""" """
if isinstance(nodes, dict):
assert len(nodes) == 1, 'The distributed in_subgraph only supports one node type for now.'
nodes = list(nodes.values())[0]
def issue_remote_req(node_ids): def issue_remote_req(node_ids):
return InSubgraphRequest(node_ids) return InSubgraphRequest(node_ids)
def local_access(local_g, partition_book, local_nids): def local_access(local_g, partition_book, local_nids):
......
...@@ -340,13 +340,21 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method= ...@@ -340,13 +340,21 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
len(local_nodes), len(local_edges))) len(local_nodes), len(local_edges)))
tot_num_inner_edges += len(local_edges) tot_num_inner_edges += len(local_edges)
for name in g.ndata: for name in g.ndata:
if name in [NID, 'inner_node']:
continue
node_feats[name] = F.gather_row(g.ndata[name], local_nodes) node_feats[name] = F.gather_row(g.ndata[name], local_nodes)
for name in g.edata: for name in g.edata:
if name in [EID, 'inner_edge']:
continue
edge_feats[name] = F.gather_row(g.edata[name], local_edges) edge_feats[name] = F.gather_row(g.edata[name], local_edges)
else: else:
for name in g.ndata: for name in g.ndata:
if name in [NID, 'inner_node']:
continue
node_feats[name] = g.ndata[name] node_feats[name] = g.ndata[name]
for name in g.edata: for name in g.edata:
if name in [EID, 'inner_edge']:
continue
edge_feats[name] = g.edata[name] edge_feats[name] = g.edata[name]
part_dir = os.path.join(out_path, "part" + str(part_id)) part_dir = os.path.join(out_path, "part" + str(part_id))
......
...@@ -262,15 +262,15 @@ def test_split(): ...@@ -262,15 +262,15 @@ def test_split():
local_nids = F.nonzero_1d(part_g.ndata['inner_node']) local_nids = F.nonzero_1d(part_g.ndata['inner_node'])
local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids) local_nids = F.gather_row(part_g.ndata[dgl.NID], local_nids)
nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids)) nodes1 = np.intersect1d(selected_nodes, F.asnumpy(local_nids))
nodes2 = node_split(node_mask, gpb, i) nodes2 = node_split(node_mask, gpb, i, force_even=False)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2))) assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes2)))
local_nids = F.asnumpy(local_nids) local_nids = F.asnumpy(local_nids)
for n in nodes1: for n in nodes1:
assert n in local_nids assert n in local_nids
dgl.distributed.set_num_client(num_parts * 2) dgl.distributed.set_num_client(num_parts * 2)
nodes3 = node_split(node_mask, gpb, i * 2) nodes3 = node_split(node_mask, gpb, i * 2, force_even=False)
nodes4 = node_split(node_mask, gpb, i * 2 + 1) nodes4 = node_split(node_mask, gpb, i * 2 + 1, force_even=False)
nodes5 = F.cat([nodes3, nodes4], 0) nodes5 = F.cat([nodes3, nodes4], 0)
assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5))) assert np.all(np.sort(nodes1) == np.sort(F.asnumpy(nodes5)))
...@@ -278,15 +278,15 @@ def test_split(): ...@@ -278,15 +278,15 @@ def test_split():
local_eids = F.nonzero_1d(part_g.edata['inner_edge']) local_eids = F.nonzero_1d(part_g.edata['inner_edge'])
local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids) local_eids = F.gather_row(part_g.edata[dgl.EID], local_eids)
edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids)) edges1 = np.intersect1d(selected_edges, F.asnumpy(local_eids))
edges2 = edge_split(edge_mask, gpb, i) edges2 = edge_split(edge_mask, gpb, i, force_even=False)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2))) assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges2)))
local_eids = F.asnumpy(local_eids) local_eids = F.asnumpy(local_eids)
for e in edges1: for e in edges1:
assert e in local_eids assert e in local_eids
dgl.distributed.set_num_client(num_parts * 2) dgl.distributed.set_num_client(num_parts * 2)
edges3 = edge_split(edge_mask, gpb, i * 2) edges3 = edge_split(edge_mask, gpb, i * 2, force_even=False)
edges4 = edge_split(edge_mask, gpb, i * 2 + 1) edges4 = edge_split(edge_mask, gpb, i * 2 + 1, force_even=False)
edges5 = F.cat([edges3, edges4], 0) edges5 = F.cat([edges3, edges4], 0)
assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5))) assert np.all(np.sort(edges1) == np.sort(F.asnumpy(edges5)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment