Unverified Commit bd1e48a5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug] Fix khop graph (#1433)

* Update

* Update

* Update
parent 88c34487
......@@ -220,6 +220,18 @@ def khop_graph(g, k):
Examples
--------
Below gives an easy example:
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1], [1, 2])
>>> g_2 = dgl.transform.khop_graph(g, 2)
>>> print(g_2.edges())
(tensor([0]), tensor([2]))
A more complicated example:
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
......@@ -234,7 +246,7 @@ def khop_graph(g, k):
edata_schemes={})
"""
n = g.number_of_nodes()
adj_k = g.adjacency_matrix_scipy(return_edge_ids=False) ** k
adj_k = g.adjacency_matrix_scipy(transpose=True, return_edge_ids=False) ** k
adj_k = adj_k.tocoo()
multiplicity = adj_k.data
row = np.repeat(adj_k.row, multiplicity)
......
......@@ -130,20 +130,27 @@ def test_bidirected_graph():
def test_khop_graph():
N = 20
feat = F.randn((N, 5))
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop graph to do message passing for one time.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
def _test(g):
for k in range(4):
g_k = dgl.khop_graph(g, k)
# use original graph to do message passing for k times.
g.ndata['h'] = feat
for _ in range(k):
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_0 = g.ndata.pop('h')
# use k-hop graph to do message passing for one time.
g_k.ndata['h'] = feat
g_k.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
h_1 = g_k.ndata.pop('h')
assert F.allclose(h_0, h_1, rtol=1e-3, atol=1e-3)
# Test for random undirected graphs
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3))
_test(g)
# Test for random directed graphs
g = dgl.DGLGraph(nx.erdos_renyi_graph(N, 0.3, directed=True))
_test(g)
def test_khop_adj():
N = 20
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment