Unverified Commit 56b5d0e5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Transform] Allow Disabling Adding Reverse Edges for Self Loops (#3701)

* Update

* Update

* Update
parent 92db4bd5
...@@ -714,13 +714,16 @@ def to_bidirected(g, copy_ndata=False, readonly=None): ...@@ -714,13 +714,16 @@ def to_bidirected(g, copy_ndata=False, readonly=None):
return g return g
def add_reverse_edges(g, readonly=None, copy_ndata=True, def add_reverse_edges(g, readonly=None, copy_ndata=True,
copy_edata=False, ignore_bipartite=False): copy_edata=False, ignore_bipartite=False, exclude_self=True):
r"""Add a reversed edge for each edge in the input graph and return a new graph. r"""Add a reversed edge for each edge in the input graph and return a new graph.
For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this For a graph with edges :math:`(i_1, j_1), \cdots, (i_n, j_n)`, this
function creates a new graph with edges function creates a new graph with edges
:math:`(i_1, j_1), \cdots, (i_n, j_n), (j_1, i_1), \cdots, (j_n, i_n)`. :math:`(i_1, j_1), \cdots, (i_n, j_n), (j_1, i_1), \cdots, (j_n, i_n)`.
The returned graph may have duplicate edges. To create a bidirected graph without
duplicate edges, use :func:`to_bidirected`.
The operation only works for edges whose two endpoints belong to the same node type. The operation only works for edges whose two endpoints belong to the same node type.
DGL will raise error if the input graph is heterogeneous and contains edges DGL will raise error if the input graph is heterogeneous and contains edges
with different types of endpoints. If :attr:`ignore_bipartite` is true, DGL will with different types of endpoints. If :attr:`ignore_bipartite` is true, DGL will
...@@ -750,6 +753,9 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -750,6 +753,9 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
no error is raised. If False, an error will be raised if no error is raised. If False, an error will be raised if
an edge type of the input heterogeneous graph is for a unidirectional an edge type of the input heterogeneous graph is for a unidirectional
bipartite graph. bipartite graph.
exclude_self: bool, optional
If True, it does not add reverse edges for self-loops, which is likely
meaningless in most cases.
Returns Returns
------- -------
...@@ -812,32 +818,45 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -812,32 +818,45 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
# get node cnt for each ntype # get node cnt for each ntype
num_nodes_dict = {} num_nodes_dict = {}
for ntype in g.ntypes: for ntype in g.ntypes:
num_nodes_dict[ntype] = g.number_of_nodes(ntype) num_nodes_dict[ntype] = g.num_nodes(ntype)
canonical_etypes = g.canonical_etypes canonical_etypes = g.canonical_etypes
num_nodes_dict = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes} num_nodes_dict = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
subgs = {}
rev_eids = {}
def add_for_etype(etype):
u, v = g.edges(form='uv', order='eid', etype=etype)
rev_u, rev_v = v, u
eid = F.copy_to(F.arange(0, g.num_edges(etype)), g.device)
if exclude_self:
self_loop_mask = F.equal(rev_u, rev_v)
non_self_loop_mask = F.logical_not(self_loop_mask)
rev_u = F.boolean_mask(rev_u, non_self_loop_mask)
rev_v = F.boolean_mask(rev_v, non_self_loop_mask)
non_self_loop_eid = F.boolean_mask(eid, non_self_loop_mask)
rev_eids[etype] = F.cat([eid, non_self_loop_eid], 0)
else:
rev_eids[etype] = F.cat([eid, eid], 0)
subgs[etype] = (F.cat([u, rev_u], dim=0), F.cat([v, rev_v], dim=0))
# fast path # fast path
if ignore_bipartite is False: if ignore_bipartite is False:
subgs = {}
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
assert False, "add_reverse_edges is not well defined for " \ assert False, "add_reverse_edges is not well defined for " \
"unidirectional bipartite graphs" \ "unidirectional bipartite graphs" \
", but {} is unidirectional bipartite".format(c_etype) ", but {} is unidirectional bipartite".format(c_etype)
add_for_etype(c_etype)
u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict) new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)
else: else:
subgs = {}
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
u, v = g.edges(form='uv', order='eid', etype=c_etype) u, v = g.edges(form='uv', order='eid', etype=c_etype)
subgs[c_etype] = (u, v) subgs[c_etype] = (u, v)
else: else:
u, v = g.edges(form='uv', order='eid', etype=c_etype) add_for_etype(c_etype)
subgs[c_etype] = (F.cat([u, v], dim=0), F.cat([v, u], dim=0))
new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict) new_g = convert.heterograph(subgs, num_nodes_dict=num_nodes_dict)
...@@ -850,11 +869,10 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True, ...@@ -850,11 +869,10 @@ def add_reverse_edges(g, readonly=None, copy_ndata=True,
# find indices # find indices
eids = [] eids = []
for c_etype in canonical_etypes: for c_etype in canonical_etypes:
eid = F.copy_to(F.arange(0, g.number_of_edges(c_etype)), new_g.device)
if c_etype[0] != c_etype[2]: if c_etype[0] != c_etype[2]:
eids.append(eid) eids.append(F.copy_to(F.arange(0, g.number_of_edges(c_etype)), new_g.device))
else: else:
eids.append(F.cat([eid, eid], 0)) eids.append(rev_eids[c_etype])
edge_frames = utils.extract_edge_subframes(g, eids) edge_frames = utils.extract_edge_subframes(g, eids)
utils.set_new_frames(new_g, edge_frames=edge_frames) utils.set_new_frames(new_g, edge_frames=edge_frames)
......
...@@ -320,7 +320,7 @@ def test_add_reverse_edges(): ...@@ -320,7 +320,7 @@ def test_add_reverse_edges():
# zero edge graph # zero edge graph
g = dgl.graph(([], [])) g = dgl.graph(([], []))
bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True) bg = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True, exclude_self=False)
# heterogeneous graph # heterogeneous graph
g = dgl.heterograph({ g = dgl.heterograph({
...@@ -401,6 +401,16 @@ def test_add_reverse_edges(): ...@@ -401,6 +401,16 @@ def test_add_reverse_edges():
assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0), assert F.array_equal(F.cat([g.edges['wins'].data['h'], g.edges['wins'].data['h']], dim=0),
bg.edges['wins'].data['h']) bg.edges['wins'].data['h'])
# test exclude_self
g = dgl.heterograph({
('A', 'r1', 'A'): (F.tensor([0, 0, 1, 1]), F.tensor([0, 1, 1, 2])),
('A', 'r2', 'A'): (F.tensor([0, 1]), F.tensor([1, 2]))
})
g.edges['r1'].data['h'] = F.tensor([0, 1, 2, 3])
rg = dgl.add_reverse_edges(g, copy_edata=True, exclude_self=True)
assert rg.num_edges('r1') == 6
assert rg.num_edges('r2') == 4
assert F.array_equal(rg.edges['r1'].data['h'], F.tensor([0, 1, 2, 3, 1, 3]))
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_simple_graph(): def test_simple_graph():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment