Unverified Commit 35c9473b authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[HotFix] Fix add_reverse_edges (#1960)



* Hot fix

* more test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
parent 3d837706
......@@ -329,6 +329,11 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
dgl_warning("Parameter readonly is deprecated" \
"There will be no difference between readonly and non-readonly DGLGraph")
# get node cnt for each ntype
num_nodes_dict = {}
for ntype in g.ntypes:
num_nodes_dict[ntype] = g.number_of_nodes(ntype)
canonical_etypes = g.canonical_etypes
# fast path
if ignore_bipartite is False:
......@@ -342,7 +347,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = convert.heterograph(subgs)
new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)
else:
subgs = {}
for c_etype in canonical_etypes:
......@@ -353,7 +358,7 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = convert.heterograph(subgs)
new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)
# handle features
if copy_ndata:
......
......@@ -362,6 +362,36 @@ def test_add_reverse_edges():
assert F.array_equal(u, ub)
assert F.array_equal(v, vb)
# test the case when some nodes have zero degree
# homogeneous graph
g = dgl.graph((F.tensor([0, 1, 3, 1]), F.tensor([1, 2, 0, 2])), num_nodes=6)
g.ndata['h'] = F.tensor([[0.], [1.], [2.], [1.], [1.], [1.]])
g.edata['h'] = F.tensor([[3.], [4.], [5.], [6.]])
bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True)
assert g.number_of_nodes() == bg.number_of_nodes()
assert F.array_equal(g.ndata['h'], bg.ndata['h'])
assert F.array_equal(F.cat([g.edata['h'], g.edata['h']], dim=0), bg.edata['h'])
# heterogeneous graph
g = dgl.heterograph({
('user', 'wins', 'user'): (F.tensor([0, 2, 0, 2, 2]), F.tensor([1, 1, 2, 1, 0])),
('user', 'plays', 'game'): (F.tensor([1, 2, 1]), F.tensor([2, 1, 1])),
('user', 'follows', 'user'): (F.tensor([1, 2, 1]), F.tensor([0, 0, 0]))},
num_nodes_dict={
'user': 5,
'game': 3
})
g.nodes['game'].data['hv'] = F.ones((3, 1))
g.nodes['user'].data['hv'] = F.ones((5, 1))
g.edges['wins'].data['h'] = F.tensor([0, 1, 2, 3, 4])
bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, ignore_bipartite=True)
assert g.number_of_nodes('user') == bg.number_of_nodes('user')
assert g.number_of_nodes('game') == bg.number_of_nodes('game')
assert F.array_equal(g.nodes['game'].data['hv'], bg.nodes['game'].data['hv'])
assert F.array_equal(g.nodes['user'].data['hv'], bg.nodes['user'].data['hv'])
assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0),
bg.edges['wins'].data['h'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_simple_graph():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment