Unverified Commit a9d6f770 authored by Nicola Vitucci's avatar Nicola Vitucci Committed by GitHub
Browse files

[Feature] Support heterogeneous graphs in the `to_networkx` method (#5726)


Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent bb43d042
...@@ -1643,8 +1643,108 @@ def bipartite_from_networkx( ...@@ -1643,8 +1643,108 @@ def bipartite_from_networkx(
return g.to(device) return g.to(device)
def to_networkx(g, node_attrs=None, edge_attrs=None): def _to_networkx_homogeneous(g, node_attrs, edge_attrs):
"""Convert a homogeneous graph to a NetworkX graph and return. # TODO: consider adding an eid_attr parameter as in
# `_to_networkx_heterogeneous` when this function is properly tested
# (see GitHub issue #5735)
src, dst = g.edges()
src = F.asnumpy(src)
dst = F.asnumpy(dst)
# xiangsx: Always treat graph as multigraph
nx_graph = nx.MultiDiGraph()
nx_graph.add_nodes_from(range(g.num_nodes()))
for eid, (u, v) in enumerate(zip(src, dst)):
nx_graph.add_edge(u, v, id=eid)
if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True):
feat_dict = g._get_n_repr(0, nid)
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in node_attrs}
)
if edge_attrs is not None:
for _, _, attr in nx_graph.edges(data=True):
eid = attr["id"]
feat_dict = g._get_e_repr(0, eid)
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in edge_attrs}
)
return nx_graph
def _to_networkx_heterogeneous(
g, node_attrs, edge_attrs, ntype_attr, etype_attr, eid_attr
):
nx_graph = nx.MultiDiGraph()
# This implementation does not use `ndata` and `edata` in the call to
# `to_homogeneous` because the function expects node and edge attributes
# both to be defined for every type and to have the same shape.
# If the `to_homogeneous` function is updated to support non-uniform node
# and edge attributes, the implementation can be simplified.
hom_g = to_homogeneous(g, store_type=True, return_count=False)
ntypes = g.ntypes
etypes = g.canonical_etypes
for hom_nid, ndata in enumerate(zip(hom_g.ndata[NID], hom_g.ndata[NTYPE])):
orig_nid, ntype = ndata
attrs = {ntype_attr: ntypes[ntype]}
if node_attrs is not None:
assert ntype_attr not in node_attrs, (
f"'{ntype_attr}' already used as node type attribute, "
f"please provide a different value for ntype_attr"
)
feat_dict = g._get_n_repr(ntype, orig_nid)
attrs.update(
{
key: F.squeeze(feat_dict[key], 0)
for key in node_attrs
if key in feat_dict
}
)
nx_graph.add_node(hom_nid, **attrs)
for hom_eid, edata in enumerate(zip(hom_g.edata[EID], hom_g.edata[ETYPE])):
orig_eid, etype = edata
attrs = {eid_attr: hom_eid, etype_attr: etypes[etype]}
if edge_attrs is not None:
assert etype_attr not in edge_attrs, (
f"'{etype_attr}' already used as edge type attribute, "
f"please provide a different value for etype_attr"
)
assert eid_attr not in edge_attrs, (
f"'{eid_attr}' already used as edge ID attribute, "
f"please provide a different value for eid_attr"
)
feat_dict = g._get_e_repr(etype, orig_eid)
attrs.update(
{
key: F.squeeze(feat_dict[key], 0)
for key in edge_attrs
if key in feat_dict
}
)
src, dst = hom_g.find_edges(hom_eid)
nx_graph.add_edge(int(src), int(dst), **attrs)
return nx_graph
def to_networkx(
g,
node_attrs=None,
edge_attrs=None,
ntype_attr="ntype",
etype_attr="etype",
eid_attr="id",
):
"""Convert a graph to a NetworkX graph and return.
The resulting NetworkX graph also contains the node/edge features of the input graph. The resulting NetworkX graph also contains the node/edge features of the input graph.
Additionally, DGL saves the edge IDs as the ``'id'`` edge attribute in the Additionally, DGL saves the edge IDs as the ``'id'`` edge attribute in the
...@@ -1653,11 +1753,21 @@ def to_networkx(g, node_attrs=None, edge_attrs=None): ...@@ -1653,11 +1753,21 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
Parameters Parameters
---------- ----------
g : DGLGraph g : DGLGraph
A homogeneous graph. A homogeneous or heterogeneous graph.
node_attrs : iterable of str, optional node_attrs : iterable of str, optional
The node attributes to copy from ``g.ndata``. (Default: None) The node attributes to copy from ``g.ndata``. (Default: None)
edge_attrs : iterable of str, optional edge_attrs : iterable of str, optional
The edge attributes to copy from ``g.edata``. (Default: None) The edge attributes to copy from ``g.edata``.
(Default: None)
ntype_attr : str, optional
The name of the node attribute to store the node types in the NetworkX object.
(Default: "ntype")
etype_attr : str, optional
The name of the edge attribute to store the edge canonical types in the NetworkX object.
(Default: "etype")
eid_attr : str, optional
The name of the edge attribute to store the original edge ID in the NetworkX object.
(Default: "id")
Returns Returns
------- -------
...@@ -1670,54 +1780,82 @@ def to_networkx(g, node_attrs=None, edge_attrs=None): ...@@ -1670,54 +1780,82 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
Examples Examples
-------- --------
The following example uses PyTorch backend. The following examples use the PyTorch backend.
>>> import dgl >>> import dgl
>>> import torch >>> import torch
With a homogeneous graph:
>>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3]))) >>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3])))
>>> g.ndata['h'] = torch.zeros(4, 1) >>> g.ndata['h'] = torch.zeros(4, 1)
>>> g.edata['h1'] = torch.ones(2, 1) >>> g.edata['h1'] = torch.ones(2, 1)
>>> g.edata['h2'] = torch.zeros(2, 2) >>> g.edata['h2'] = torch.zeros(2, 2)
>>> nx_g = dgl.to_networkx(g, node_attrs=['h'], edge_attrs=['h1', 'h2']) >>> nx_g = dgl.to_networkx(g, node_attrs=['h'], edge_attrs=['h1', 'h2'])
>>> nx_g.nodes(data=True) >>> nx_g.nodes(data=True)
NodeDataView({0: {'h': tensor([0.])}, NodeDataView({
1: {'h': tensor([0.])}, 0: {'h': tensor([0.])},
2: {'h': tensor([0.])}, 1: {'h': tensor([0.])},
3: {'h': tensor([0.])}}) 2: {'h': tensor([0.])},
3: {'h': tensor([0.])}
})
>>> nx_g.edges(data=True) >>> nx_g.edges(data=True)
OutMultiEdgeDataView([(1, 1, {'id': 0, 'h1': tensor([1.]), 'h2': tensor([0., 0.])}), OutMultiEdgeDataView([
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})]) (1, 1, {'id': 0, 'h1': tensor([1.]), 'h2': tensor([0., 0.])}),
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})
])
With a heterogeneous graph:
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
... ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
... ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
... })
... g.ndata['n'] = {
... 'game': torch.zeros(5, 1),
... 'user': torch.ones(4, 1)
... }
... g.edata['e'] = {
... ('user', 'follows', 'user'): torch.zeros(2, 1),
... 'plays': torch.ones(2, 1)
... }
>>> nx_g = dgl.to_networkx(g, node_attrs=['n'], edge_attrs=['e'])
>>> nx_g.nodes(data=True)
NodeDataView({
0: {'ntype': 'game', 'n': tensor([0.])},
1: {'ntype': 'game', 'n': tensor([0.])},
2: {'ntype': 'game', 'n': tensor([0.])},
3: {'ntype': 'game', 'n': tensor([0.])},
4: {'ntype': 'game', 'n': tensor([0.])},
5: {'ntype': 'topic'},
6: {'ntype': 'topic'},
7: {'ntype': 'topic'},
8: {'ntype': 'user', 'n': tensor([1.])},
9: {'ntype': 'user', 'n': tensor([1.])},
10: {'ntype': 'user', 'n': tensor([1.])},
11: {'ntype': 'user', 'n': tensor([1.])}
})
>>> nx_g.edges(data=True)
OutMultiEdgeDataView([
(8, 9, {'id': 2, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),
(8, 3, {'id': 4, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])}),
(9, 6, {'id': 0, 'etype': ('user', 'follows', 'topic')}),
(9, 7, {'id': 1, 'etype': ('user', 'follows', 'topic')}),
(9, 10, {'id': 3, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),
(11, 4, {'id': 5, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])})
])
""" """
if g.device != F.cpu(): if g.device != F.cpu():
raise DGLError( raise DGLError(
"Cannot convert a CUDA graph to networkx. Call g.cpu() first." "Cannot convert a CUDA graph to networkx. Call g.cpu() first."
) )
if not g.is_homogeneous: if g.is_homogeneous:
raise DGLError("dgl.to_networkx only supports homogeneous graphs.") return _to_networkx_homogeneous(g, node_attrs, edge_attrs)
src, dst = g.edges() else:
src = F.asnumpy(src) return _to_networkx_heterogeneous(
dst = F.asnumpy(dst) g, node_attrs, edge_attrs, ntype_attr, etype_attr, eid_attr
# xiangsx: Always treat graph as multigraph )
nx_graph = nx.MultiDiGraph()
nx_graph.add_nodes_from(range(g.num_nodes()))
for eid, (u, v) in enumerate(zip(src, dst)):
nx_graph.add_edge(u, v, id=eid)
if node_attrs is not None:
for nid, attr in nx_graph.nodes(data=True):
feat_dict = g._get_n_repr(0, nid)
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in node_attrs}
)
if edge_attrs is not None:
for _, _, attr in nx_graph.edges(data=True):
eid = attr["id"]
feat_dict = g._get_e_repr(0, eid)
attr.update(
{key: F.squeeze(feat_dict[key], 0) for key in edge_attrs}
)
return nx_graph
DGLGraph.to_networkx = to_networkx DGLGraph.to_networkx = to_networkx
......
import unittest
import backend as F
import dgl
from utils import parametrize_idtype
def get_nodes_by_ntype(nodes, ntype):
return dict((k, v) for k, v in nodes.items() if v["ntype"] == ntype)
def edge_attrs(edge):
# Edges in Networkx are in the format (src, dst, attrs)
return edge[2]
def get_edges_by_etype(edges, etype):
return [e for e in edges if edge_attrs(e)["etype"] == etype]
def check_attrs_for_nodes(nodes, attrs):
return all(v.keys() == attrs for v in nodes.values())
def check_attr_values_for_nodes(nodes, attr_name, values):
return F.allclose(
F.stack([v[attr_name] for v in nodes.values()], 0), values
)
def check_attrs_for_edges(edges, attrs):
return all(edge_attrs(e).keys() == attrs for e in edges)
def check_attr_values_for_edges(edges, attr_name, values):
return F.allclose(
F.stack([edge_attrs(e)[attr_name] for e in edges], 0), values
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="`to_networkx` does not support graphs on GPU",
)
@parametrize_idtype
def test_to_networkx(idtype):
# TODO: adapt and move code from the _test_nx_conversion function in
# tests/python/common/function/test_basics.py to here
# (pending resolution of https://github.com/dmlc/dgl/issues/5735).
g = dgl.heterograph(
{
("user", "follows", "user"): ([0, 1], [1, 2]),
("user", "follows", "topic"): ([1, 1], [1, 2]),
("user", "plays", "game"): ([0, 3], [3, 4]),
},
idtype=idtype,
device=F.ctx(),
)
n1 = F.randn((5, 3))
n2 = F.randn((4, 2))
e1 = F.randn((2, 3))
e2 = F.randn((2, 2))
g.nodes["game"].data["n"] = F.copy_to(n1, ctx=F.ctx())
g.nodes["user"].data["n"] = F.copy_to(n2, ctx=F.ctx())
g.edges[("user", "follows", "user")].data["e"] = F.copy_to(e1, ctx=F.ctx())
g.edges["plays"].data["e"] = F.copy_to(e2, ctx=F.ctx())
nxg = dgl.to_networkx(
g,
node_attrs=["n"],
edge_attrs=["e"],
)
# Test nodes
nxg_nodes = dict(nxg.nodes(data=True))
assert len(nxg_nodes) == g.num_nodes()
assert {v["ntype"] for v in nxg_nodes.values()} == set(g.ntypes)
nxg_nodes_by_ntype = {}
for ntype in g.ntypes:
nxg_nodes_by_ntype[ntype] = get_nodes_by_ntype(nxg_nodes, ntype)
assert g.num_nodes(ntype) == len(nxg_nodes_by_ntype[ntype])
assert check_attrs_for_nodes(nxg_nodes_by_ntype["game"], {"ntype", "n"})
assert check_attr_values_for_nodes(nxg_nodes_by_ntype["game"], "n", n1)
assert check_attrs_for_nodes(nxg_nodes_by_ntype["user"], {"ntype", "n"})
assert check_attr_values_for_nodes(nxg_nodes_by_ntype["user"], "n", n2)
# Nodes without node attributes
assert check_attrs_for_nodes(nxg_nodes_by_ntype["topic"], {"ntype"})
# Test edges
nxg_edges = list(nxg.edges(data=True))
assert len(nxg_edges) == g.num_edges()
assert {edge_attrs(e)["etype"] for e in nxg_edges} == set(
g.canonical_etypes
)
nxg_edges_by_etype = {}
for etype in g.canonical_etypes:
nxg_edges_by_etype[etype] = get_edges_by_etype(nxg_edges, etype)
assert g.num_edges(etype) == len(nxg_edges_by_etype[etype])
assert check_attrs_for_edges(
nxg_edges_by_etype[("user", "follows", "user")],
{"id", "etype", "e"},
)
assert check_attr_values_for_edges(
nxg_edges_by_etype[("user", "follows", "user")], "e", e1
)
assert check_attrs_for_edges(
nxg_edges_by_etype[("user", "plays", "game")], {"id", "etype", "e"}
)
assert check_attr_values_for_edges(
nxg_edges_by_etype[("user", "plays", "game")], "e", e2
)
# Edges without edge attributes
assert check_attrs_for_edges(
nxg_edges_by_etype[("user", "follows", "topic")], {"id", "etype"}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment