Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a9d6f770
Unverified
Commit
a9d6f770
authored
Jun 01, 2023
by
Nicola Vitucci
Committed by
GitHub
Jun 01, 2023
Browse files
[Feature] Support heterogeneous graphs in the `to_networkx` method (#5726)
Co-authored-by:
Mufei Li
<
mufeili1996@gmail.com
>
parent
bb43d042
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
296 additions
and
36 deletions
+296
-36
python/dgl/convert.py
python/dgl/convert.py
+174
-36
tests/python/common/test_convert.py
tests/python/common/test_convert.py
+122
-0
No files found.
python/dgl/convert.py
View file @
a9d6f770
...
@@ -1643,8 +1643,108 @@ def bipartite_from_networkx(
...
@@ -1643,8 +1643,108 @@ def bipartite_from_networkx(
return
g
.
to
(
device
)
return
g
.
to
(
device
)
def
to_networkx
(
g
,
node_attrs
=
None
,
edge_attrs
=
None
):
def
_to_networkx_homogeneous
(
g
,
node_attrs
,
edge_attrs
):
"""Convert a homogeneous graph to a NetworkX graph and return.
# TODO: consider adding an eid_attr parameter as in
# `_to_networkx_heterogeneous` when this function is properly tested
# (see GitHub issue #5735)
src
,
dst
=
g
.
edges
()
src
=
F
.
asnumpy
(
src
)
dst
=
F
.
asnumpy
(
dst
)
# xiangsx: Always treat graph as multigraph
nx_graph
=
nx
.
MultiDiGraph
()
nx_graph
.
add_nodes_from
(
range
(
g
.
num_nodes
()))
for
eid
,
(
u
,
v
)
in
enumerate
(
zip
(
src
,
dst
)):
nx_graph
.
add_edge
(
u
,
v
,
id
=
eid
)
if
node_attrs
is
not
None
:
for
nid
,
attr
in
nx_graph
.
nodes
(
data
=
True
):
feat_dict
=
g
.
_get_n_repr
(
0
,
nid
)
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
node_attrs
}
)
if
edge_attrs
is
not
None
:
for
_
,
_
,
attr
in
nx_graph
.
edges
(
data
=
True
):
eid
=
attr
[
"id"
]
feat_dict
=
g
.
_get_e_repr
(
0
,
eid
)
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
edge_attrs
}
)
return
nx_graph
def
_to_networkx_heterogeneous
(
g
,
node_attrs
,
edge_attrs
,
ntype_attr
,
etype_attr
,
eid_attr
):
nx_graph
=
nx
.
MultiDiGraph
()
# This implementation does not use `ndata` and `edata` in the call to
# `to_homogeneous` because the function expects node and edge attributes
# both to be defined for every type and to have the same shape.
# If the `to_homogeneous` function is updated to support non-uniform node
# and edge attributes, the implementation can be simplified.
hom_g
=
to_homogeneous
(
g
,
store_type
=
True
,
return_count
=
False
)
ntypes
=
g
.
ntypes
etypes
=
g
.
canonical_etypes
for
hom_nid
,
ndata
in
enumerate
(
zip
(
hom_g
.
ndata
[
NID
],
hom_g
.
ndata
[
NTYPE
])):
orig_nid
,
ntype
=
ndata
attrs
=
{
ntype_attr
:
ntypes
[
ntype
]}
if
node_attrs
is
not
None
:
assert
ntype_attr
not
in
node_attrs
,
(
f
"'
{
ntype_attr
}
' already used as node type attribute, "
f
"please provide a different value for ntype_attr"
)
feat_dict
=
g
.
_get_n_repr
(
ntype
,
orig_nid
)
attrs
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
node_attrs
if
key
in
feat_dict
}
)
nx_graph
.
add_node
(
hom_nid
,
**
attrs
)
for
hom_eid
,
edata
in
enumerate
(
zip
(
hom_g
.
edata
[
EID
],
hom_g
.
edata
[
ETYPE
])):
orig_eid
,
etype
=
edata
attrs
=
{
eid_attr
:
hom_eid
,
etype_attr
:
etypes
[
etype
]}
if
edge_attrs
is
not
None
:
assert
etype_attr
not
in
edge_attrs
,
(
f
"'
{
etype_attr
}
' already used as edge type attribute, "
f
"please provide a different value for etype_attr"
)
assert
eid_attr
not
in
edge_attrs
,
(
f
"'
{
eid_attr
}
' already used as edge ID attribute, "
f
"please provide a different value for eid_attr"
)
feat_dict
=
g
.
_get_e_repr
(
etype
,
orig_eid
)
attrs
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
edge_attrs
if
key
in
feat_dict
}
)
src
,
dst
=
hom_g
.
find_edges
(
hom_eid
)
nx_graph
.
add_edge
(
int
(
src
),
int
(
dst
),
**
attrs
)
return
nx_graph
def
to_networkx
(
g
,
node_attrs
=
None
,
edge_attrs
=
None
,
ntype_attr
=
"ntype"
,
etype_attr
=
"etype"
,
eid_attr
=
"id"
,
):
"""Convert a graph to a NetworkX graph and return.
The resulting NetworkX graph also contains the node/edge features of the input graph.
The resulting NetworkX graph also contains the node/edge features of the input graph.
Additionally, DGL saves the edge IDs as the ``'id'`` edge attribute in the
Additionally, DGL saves the edge IDs as the ``'id'`` edge attribute in the
...
@@ -1653,11 +1753,21 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
...
@@ -1653,11 +1753,21 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
Parameters
Parameters
----------
----------
g : DGLGraph
g : DGLGraph
A homogeneous graph.
A homogeneous
or heterogeneous
graph.
node_attrs : iterable of str, optional
node_attrs : iterable of str, optional
The node attributes to copy from ``g.ndata``. (Default: None)
The node attributes to copy from ``g.ndata``. (Default: None)
edge_attrs : iterable of str, optional
edge_attrs : iterable of str, optional
The edge attributes to copy from ``g.edata``. (Default: None)
The edge attributes to copy from ``g.edata``.
(Default: None)
ntype_attr : str, optional
The name of the node attribute to store the node types in the NetworkX object.
(Default: "ntype")
etype_attr : str, optional
The name of the edge attribute to store the edge canonical types in the NetworkX object.
(Default: "etype")
eid_attr : str, optional
The name of the edge attribute to store the original edge ID in the NetworkX object.
(Default: "id")
Returns
Returns
-------
-------
...
@@ -1670,54 +1780,82 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
...
@@ -1670,54 +1780,82 @@ def to_networkx(g, node_attrs=None, edge_attrs=None):
Examples
Examples
--------
--------
The following example use
s
PyTorch backend.
The following example
s
use
the
PyTorch backend.
>>> import dgl
>>> import dgl
>>> import torch
>>> import torch
With a homogeneous graph:
>>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3])))
>>> g = dgl.graph((torch.tensor([1, 2]), torch.tensor([1, 3])))
>>> g.ndata['h'] = torch.zeros(4, 1)
>>> g.ndata['h'] = torch.zeros(4, 1)
>>> g.edata['h1'] = torch.ones(2, 1)
>>> g.edata['h1'] = torch.ones(2, 1)
>>> g.edata['h2'] = torch.zeros(2, 2)
>>> g.edata['h2'] = torch.zeros(2, 2)
>>> nx_g = dgl.to_networkx(g, node_attrs=['h'], edge_attrs=['h1', 'h2'])
>>> nx_g = dgl.to_networkx(g, node_attrs=['h'], edge_attrs=['h1', 'h2'])
>>> nx_g.nodes(data=True)
>>> nx_g.nodes(data=True)
NodeDataView({0: {'h': tensor([0.])},
NodeDataView({
1: {'h': tensor([0.])},
0: {'h': tensor([0.])},
2: {'h': tensor([0.])},
1: {'h': tensor([0.])},
3: {'h': tensor([0.])}})
2: {'h': tensor([0.])},
3: {'h': tensor([0.])}
})
>>> nx_g.edges(data=True)
>>> nx_g.edges(data=True)
OutMultiEdgeDataView([(1, 1, {'id': 0, 'h1': tensor([1.]), 'h2': tensor([0., 0.])}),
OutMultiEdgeDataView([
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})])
(1, 1, {'id': 0, 'h1': tensor([1.]), 'h2': tensor([0., 0.])}),
(2, 3, {'id': 1, 'h1': tensor([1.]), 'h2': tensor([0., 0.])})
])
With a heterogeneous graph:
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
... ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
... ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
... })
... g.ndata['n'] = {
... 'game': torch.zeros(5, 1),
... 'user': torch.ones(4, 1)
... }
... g.edata['e'] = {
... ('user', 'follows', 'user'): torch.zeros(2, 1),
... 'plays': torch.ones(2, 1)
... }
>>> nx_g = dgl.to_networkx(g, node_attrs=['n'], edge_attrs=['e'])
>>> nx_g.nodes(data=True)
NodeDataView({
0: {'ntype': 'game', 'n': tensor([0.])},
1: {'ntype': 'game', 'n': tensor([0.])},
2: {'ntype': 'game', 'n': tensor([0.])},
3: {'ntype': 'game', 'n': tensor([0.])},
4: {'ntype': 'game', 'n': tensor([0.])},
5: {'ntype': 'topic'},
6: {'ntype': 'topic'},
7: {'ntype': 'topic'},
8: {'ntype': 'user', 'n': tensor([1.])},
9: {'ntype': 'user', 'n': tensor([1.])},
10: {'ntype': 'user', 'n': tensor([1.])},
11: {'ntype': 'user', 'n': tensor([1.])}
})
>>> nx_g.edges(data=True)
OutMultiEdgeDataView([
(8, 9, {'id': 2, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),
(8, 3, {'id': 4, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])}),
(9, 6, {'id': 0, 'etype': ('user', 'follows', 'topic')}),
(9, 7, {'id': 1, 'etype': ('user', 'follows', 'topic')}),
(9, 10, {'id': 3, 'etype': ('user', 'follows', 'user'), 'e': tensor([0.])}),
(11, 4, {'id': 5, 'etype': ('user', 'plays', 'game'), 'e': tensor([1.])})
])
"""
"""
if
g
.
device
!=
F
.
cpu
():
if
g
.
device
!=
F
.
cpu
():
raise
DGLError
(
raise
DGLError
(
"Cannot convert a CUDA graph to networkx. Call g.cpu() first."
"Cannot convert a CUDA graph to networkx. Call g.cpu() first."
)
)
if
not
g
.
is_homogeneous
:
if
g
.
is_homogeneous
:
raise
DGLError
(
"dgl.to_networkx only supports homogeneous graphs."
)
return
_to_networkx_homogeneous
(
g
,
node_attrs
,
edge_attrs
)
src
,
dst
=
g
.
edges
()
else
:
src
=
F
.
asnumpy
(
src
)
return
_to_networkx_heterogeneous
(
dst
=
F
.
asnumpy
(
dst
)
g
,
node_attrs
,
edge_attrs
,
ntype_attr
,
etype_attr
,
eid_attr
# xiangsx: Always treat graph as multigraph
)
nx_graph
=
nx
.
MultiDiGraph
()
nx_graph
.
add_nodes_from
(
range
(
g
.
num_nodes
()))
for
eid
,
(
u
,
v
)
in
enumerate
(
zip
(
src
,
dst
)):
nx_graph
.
add_edge
(
u
,
v
,
id
=
eid
)
if
node_attrs
is
not
None
:
for
nid
,
attr
in
nx_graph
.
nodes
(
data
=
True
):
feat_dict
=
g
.
_get_n_repr
(
0
,
nid
)
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
node_attrs
}
)
if
edge_attrs
is
not
None
:
for
_
,
_
,
attr
in
nx_graph
.
edges
(
data
=
True
):
eid
=
attr
[
"id"
]
feat_dict
=
g
.
_get_e_repr
(
0
,
eid
)
attr
.
update
(
{
key
:
F
.
squeeze
(
feat_dict
[
key
],
0
)
for
key
in
edge_attrs
}
)
return
nx_graph
DGLGraph
.
to_networkx
=
to_networkx
DGLGraph
.
to_networkx
=
to_networkx
...
...
tests/python/common/test_convert.py
0 → 100644
View file @
a9d6f770
import
unittest
import
backend
as
F
import
dgl
from
utils
import
parametrize_idtype
def
get_nodes_by_ntype
(
nodes
,
ntype
):
return
dict
((
k
,
v
)
for
k
,
v
in
nodes
.
items
()
if
v
[
"ntype"
]
==
ntype
)
def
edge_attrs
(
edge
):
# Edges in Networkx are in the format (src, dst, attrs)
return
edge
[
2
]
def
get_edges_by_etype
(
edges
,
etype
):
return
[
e
for
e
in
edges
if
edge_attrs
(
e
)[
"etype"
]
==
etype
]
def
check_attrs_for_nodes
(
nodes
,
attrs
):
return
all
(
v
.
keys
()
==
attrs
for
v
in
nodes
.
values
())
def
check_attr_values_for_nodes
(
nodes
,
attr_name
,
values
):
return
F
.
allclose
(
F
.
stack
([
v
[
attr_name
]
for
v
in
nodes
.
values
()],
0
),
values
)
def
check_attrs_for_edges
(
edges
,
attrs
):
return
all
(
edge_attrs
(
e
).
keys
()
==
attrs
for
e
in
edges
)
def
check_attr_values_for_edges
(
edges
,
attr_name
,
values
):
return
F
.
allclose
(
F
.
stack
([
edge_attrs
(
e
)[
attr_name
]
for
e
in
edges
],
0
),
values
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"`to_networkx` does not support graphs on GPU"
,
)
@
parametrize_idtype
def
test_to_networkx
(
idtype
):
# TODO: adapt and move code from the _test_nx_conversion function in
# tests/python/common/function/test_basics.py to here
# (pending resolution of https://github.com/dmlc/dgl/issues/5735).
g
=
dgl
.
heterograph
(
{
(
"user"
,
"follows"
,
"user"
):
([
0
,
1
],
[
1
,
2
]),
(
"user"
,
"follows"
,
"topic"
):
([
1
,
1
],
[
1
,
2
]),
(
"user"
,
"plays"
,
"game"
):
([
0
,
3
],
[
3
,
4
]),
},
idtype
=
idtype
,
device
=
F
.
ctx
(),
)
n1
=
F
.
randn
((
5
,
3
))
n2
=
F
.
randn
((
4
,
2
))
e1
=
F
.
randn
((
2
,
3
))
e2
=
F
.
randn
((
2
,
2
))
g
.
nodes
[
"game"
].
data
[
"n"
]
=
F
.
copy_to
(
n1
,
ctx
=
F
.
ctx
())
g
.
nodes
[
"user"
].
data
[
"n"
]
=
F
.
copy_to
(
n2
,
ctx
=
F
.
ctx
())
g
.
edges
[(
"user"
,
"follows"
,
"user"
)].
data
[
"e"
]
=
F
.
copy_to
(
e1
,
ctx
=
F
.
ctx
())
g
.
edges
[
"plays"
].
data
[
"e"
]
=
F
.
copy_to
(
e2
,
ctx
=
F
.
ctx
())
nxg
=
dgl
.
to_networkx
(
g
,
node_attrs
=
[
"n"
],
edge_attrs
=
[
"e"
],
)
# Test nodes
nxg_nodes
=
dict
(
nxg
.
nodes
(
data
=
True
))
assert
len
(
nxg_nodes
)
==
g
.
num_nodes
()
assert
{
v
[
"ntype"
]
for
v
in
nxg_nodes
.
values
()}
==
set
(
g
.
ntypes
)
nxg_nodes_by_ntype
=
{}
for
ntype
in
g
.
ntypes
:
nxg_nodes_by_ntype
[
ntype
]
=
get_nodes_by_ntype
(
nxg_nodes
,
ntype
)
assert
g
.
num_nodes
(
ntype
)
==
len
(
nxg_nodes_by_ntype
[
ntype
])
assert
check_attrs_for_nodes
(
nxg_nodes_by_ntype
[
"game"
],
{
"ntype"
,
"n"
})
assert
check_attr_values_for_nodes
(
nxg_nodes_by_ntype
[
"game"
],
"n"
,
n1
)
assert
check_attrs_for_nodes
(
nxg_nodes_by_ntype
[
"user"
],
{
"ntype"
,
"n"
})
assert
check_attr_values_for_nodes
(
nxg_nodes_by_ntype
[
"user"
],
"n"
,
n2
)
# Nodes without node attributes
assert
check_attrs_for_nodes
(
nxg_nodes_by_ntype
[
"topic"
],
{
"ntype"
})
# Test edges
nxg_edges
=
list
(
nxg
.
edges
(
data
=
True
))
assert
len
(
nxg_edges
)
==
g
.
num_edges
()
assert
{
edge_attrs
(
e
)[
"etype"
]
for
e
in
nxg_edges
}
==
set
(
g
.
canonical_etypes
)
nxg_edges_by_etype
=
{}
for
etype
in
g
.
canonical_etypes
:
nxg_edges_by_etype
[
etype
]
=
get_edges_by_etype
(
nxg_edges
,
etype
)
assert
g
.
num_edges
(
etype
)
==
len
(
nxg_edges_by_etype
[
etype
])
assert
check_attrs_for_edges
(
nxg_edges_by_etype
[(
"user"
,
"follows"
,
"user"
)],
{
"id"
,
"etype"
,
"e"
},
)
assert
check_attr_values_for_edges
(
nxg_edges_by_etype
[(
"user"
,
"follows"
,
"user"
)],
"e"
,
e1
)
assert
check_attrs_for_edges
(
nxg_edges_by_etype
[(
"user"
,
"plays"
,
"game"
)],
{
"id"
,
"etype"
,
"e"
}
)
assert
check_attr_values_for_edges
(
nxg_edges_by_etype
[(
"user"
,
"plays"
,
"game"
)],
"e"
,
e2
)
# Edges without edge attributes
assert
check_attrs_for_edges
(
nxg_edges_by_etype
[(
"user"
,
"follows"
,
"topic"
)],
{
"id"
,
"etype"
}
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment