Unverified Commit 7f83d745 authored by Andrew's avatar Andrew Committed by GitHub
Browse files

[Feature] method for merging graphs into graphs (#3522)



* Added graph updating method and tests. resolves #3488

* removed spaces around named args

* customizing indices for graph's idtype and ctx

* changing torch ops to generic backend ops

* changing tensors to np arrays

* created dgl merge function and tests

* Changed per-graph edge updates to single ag update

* removed update method and tests

* reformat newlines & spaces

* concatenating in one-shot instead of iteratively
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 4b295d60
......@@ -33,6 +33,7 @@ from .convert import *
from .generators import *
from .heterograph import DGLHeteroGraph
from .heterograph import DGLHeteroGraph as DGLGraph # pylint: disable=reimported
from .merge import *
from .subgraph import *
from .traversal import *
from .transform import *
......
"""Utilities for merging graphs."""
import dgl
from . import backend as F
from .base import DGLError
__all__ = ['merge']
def merge(graphs):
r"""Merge a sequence of graphs together into a single graph.
Nodes and edges that exist in ``graphs[i+1]`` but not in ``dgl.merge(graphs[0:i+1])``
will be added to ``dgl.merge(graphs[0:i+1])`` along with their data.
Nodes that exist in both ``dgl.merge(graphs[0:i+1])`` and ``graphs[i+1]``
will be updated with ``graphs[i+1]``'s data if they do not match.
Parameters
----------
graphs : list[DGLGraph]
Input graphs.
Notes
----------
* Inplace updates are applied to a new, empty graph.
* Features that exist in ``dgl.graphs[i+1]`` will be created in
``dgl.merge(dgl.graphs[i+1])`` if they do not already exist.
Examples
----------
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
>>> g = dgl.graph((torch.tensor([0,1]), torch.tensor([2,3])))
>>> g.ndata["x"] = torch.zeros(4)
>>> h = dgl.graph((torch.tensor([1,2]), torch.tensor([0,4])))
>>> h.ndata["x"] = torch.ones(5)
>>> m = dgl.merge([g,h])
``m`` now contains edges and nodes from ``h`` and ``g``.
>>> m.edges()
(tensor([0, 1, 1, 2]), tensor([2, 3, 0, 4]))
>>> m.nodes()
tensor([0, 1, 2, 3, 4])
``g``'s data has updated with ``h``'s in ``m``.
>>> m.ndata["x"]
tensor([1., 1., 1., 1., 1.])
See Also
----------
add_nodes
add_edges
"""
if len(graphs) == 0:
raise DGLError('The input list of graphs cannot be empty.')
ref = graphs[0]
ntypes = ref.ntypes
etypes = ref.canonical_etypes
data_dict = {etype: ([], []) for etype in etypes}
num_nodes_dict = {ntype: 0 for ntype in ntypes}
merged = dgl.heterograph(data_dict, num_nodes_dict, ref.idtype, ref.device)
# Merge edges and edge data.
for etype in etypes:
unmerged_us = []
unmerged_vs = []
edata_frames = []
for graph in graphs:
etype_id = graph.get_etype_id(etype)
us, vs = graph.edges(etype=etype)
unmerged_us.append(us)
unmerged_vs.append(vs)
edge_data = graph._edge_frames[etype_id]
edata_frames.append(edge_data)
keys = ref.edges[etype].data.keys()
if len(keys) == 0:
edges_data = None
else:
edges_data = {k: F.cat([f[k] for f in edata_frames], dim=0) for k in keys}
merged_us = F.copy_to(F.astype(F.cat(unmerged_us, dim=0), ref.idtype), ref.device)
merged_vs = F.copy_to(F.astype(F.cat(unmerged_vs, dim=0), ref.idtype), ref.device)
merged.add_edges(merged_us, merged_vs, edges_data, etype)
# Add node data and isolated nodes from next_graph to merged.
for next_graph in graphs:
for ntype in ntypes:
merged_ntype_id = merged.get_ntype_id(ntype)
next_ntype_id = next_graph.get_ntype_id(ntype)
next_ndata = next_graph._node_frames[next_ntype_id]
node_diff = (next_graph.num_nodes(ntype=ntype) -
merged.num_nodes(ntype=ntype))
n_extra_nodes = max(0, node_diff)
merged.add_nodes(n_extra_nodes, ntype=ntype)
next_nodes = F.arange(
0, next_graph.num_nodes(ntype=ntype), merged.idtype, merged.device
)
merged._node_frames[merged_ntype_id].update_row(
next_nodes, next_ndata
)
return merged
import backend as F
from test_utils import parametrize_dtype
import dgl
@parametrize_dtype
def test_heterograph_merge(idtype):
g1 = dgl.heterograph({("a", "to", "b"): ([0,1], [1,0])}).astype(idtype).to(F.ctx())
g1_n_edges = g1.num_edges(etype="to")
g1.nodes["a"].data["nh"] = F.randn((2,3))
g1.nodes["b"].data["nh"] = F.randn((2,3))
g1.edges["to"].data["eh"] = F.randn((2,3))
g2 = dgl.heterograph({("a", "to", "b"): ([1,2,3], [2,3,5])}).astype(idtype).to(F.ctx())
g2.nodes["a"].data["nh"] = F.randn((4,3))
g2.nodes["b"].data["nh"] = F.randn((6,3))
g2.edges["to"].data["eh"] = F.randn((3,3))
g2.add_nodes(3, ntype="a")
g2.add_nodes(3, ntype="b")
m = dgl.merge([g1, g2])
# Check g2's edges and nodes were added to g1's in m.
m_us = F.asnumpy(m.edges()[0][g1_n_edges:])
g2_us = F.asnumpy(g2.edges()[0])
assert all(m_us == g2_us)
m_vs = F.asnumpy(m.edges()[1][g1_n_edges:])
g2_vs = F.asnumpy(g2.edges()[1])
assert all(m_vs == g2_vs)
for ntype in m.ntypes:
assert m.num_nodes(ntype=ntype) == max(
g1.num_nodes(ntype=ntype), g2.num_nodes(ntype=ntype)
)
# Check g1's node data was updated with g2's in m.
for key in m.nodes[ntype].data:
g2_n_nodes = g2.num_nodes(ntype=ntype)
updated_g1_ndata = F.asnumpy(m.nodes[ntype].data[key][:g2_n_nodes])
g2_ndata = F.asnumpy(g2.nodes[ntype].data[key])
assert all(
(updated_g1_ndata == g2_ndata).flatten()
)
# Check g1's edge data was updated with g2's in m.
for key in m.edges["to"].data:
updated_g1_edata = F.asnumpy(m.edges["to"].data[key][g1_n_edges:])
g2_edata = F.asnumpy(g2.edges["to"].data[key])
assert all(
(updated_g1_edata == g2_edata).flatten()
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment