Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ac570c1d
Unverified
Commit
ac570c1d
authored
Sep 10, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Sep 10, 2020
Browse files
[Bugfix] Fix flatten not wrapping unit graph (#2170)
* fix flatten not wrapping unit graph * fix doc
parent
2c04ecb5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
3 deletions
+60
-3
python/dgl/heterograph.py
python/dgl/heterograph.py
+57
-2
src/graph/heterograph.cc
src/graph/heterograph.cc
+1
-1
tests/compute/test_heterograph.py
tests/compute/test_heterograph.py
+2
-0
No files found.
python/dgl/heterograph.py
View file @
ac570c1d
...
...
@@ -1875,7 +1875,7 @@ class DGLHeteroGraph(object):
def
__getitem__
(
self
,
key
):
"""Return the relation slice of this graph.
A
relation slice
is accessed
with ``self[srctype, etype, dsttype]``, where
You can get a
relation slice with ``self[srctype, etype, dsttype]``, where
``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
slice (``:``) representing wildcard (i.e. any source/edge/destination type).
...
...
@@ -1893,8 +1893,63 @@ class DGLHeteroGraph(object):
new source/destination node type would have the concatenation determined by
:func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination
types as its name. The source/destination node would be formed by concatenating the
common features of the original source/destination types
, t
herefore they are not
common features of the original source/destination types
. T
herefore they are not
shared with the original graph. Edge type is similar.
Parameters
----------
key : str or tuple
Either a string representing the edge type name, or a tuple in the form of
``(srctype, etype, dsttype)`` where ``srctype``, ``etype``, ``dsttype`` can be either
strings representing type names or a full slice object (`:`).
Returns
-------
DGLGraph
The relation slice.
Notes
-----
This function returns a new graph. Changing the content of this graph does not reflect
onto the original graph.
If the graph combines multiple node types or edge types together, it will have the
mapping of node/edge types and IDs from the new graph to the original graph.
The mappings have the name ``dgl.NTYPE``, ``dgl.NID``, ``dgl.ETYPE`` and ``dgl.EID``,
similar to the function :func:`dgl.to_homogenenous`.
Examples
--------
>>> g = dgl.heterograph({
... ('A1', 'AB1', 'B'): ([0, 1, 2], [1, 2, 3]),
... ('A1', 'AB2', 'B'): ([1, 2, 3], [3, 4, 5]),
... ('A2', 'AB2', 'B'): ([1, 3, 5], [2, 4, 6])})
>>> new_g = g['A1', :, 'B'] # combines all edge types between A1 and B
>>> new_g
Graph(num_nodes={'A1': 4, 'B': 7},
num_edges={('A1', 'AB1+AB2', 'B'): 6},
metagraph=[('A1', 'B', 'AB1+AB2')])
>>> new_g.edges()
(tensor([0, 1, 2, 1, 2, 3]), tensor([1, 2, 3, 3, 4, 5]))
>>> new_g2 = g[:, 'AB2', 'B'] # combines all node types that are source of AB2
>>> new_g2
Graph(num_nodes={'A1+A2': 10, 'B': 7},
num_edges={('A1+A2', 'AB2+AB2', 'B'): 6},
metagraph=[('A1+A2', 'B', 'AB2+AB2')])
>>> new_g2.edges()
(tensor([1, 2, 3, 5, 7, 9]), tensor([3, 4, 5, 2, 4, 6]))
If a combination of multiple node types and edge types occur, one can find
the mapping to the original node type and IDs like the following:
>>> new_g1.edges['AB1+AB2'].data[dgl.EID]
tensor([0, 1, 2, 0, 1, 2])
>>> new_g1.edges['AB1+AB2'].data[dgl.ETYPE]
tensor([0, 0, 0, 1, 1, 1])
>>> new_g2.nodes['A1+A2'].data[dgl.NID]
tensor([0, 1, 2, 3, 0, 1, 2, 3, 4, 5])
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
"""
err_msg
=
"Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] "
+
\
"to get view of one relation type. Use : to slice multiple types (e.g. "
+
\
...
...
src/graph/heterograph.cc
View file @
ac570c1d
...
...
@@ -484,7 +484,7 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
CHECK_EQ
(
gptr
->
NumBits
(),
NumBits
());
FlattenedHeteroGraph
*
result
=
new
FlattenedHeteroGraph
;
result
->
graph
=
HeteroGraphRef
(
gptr
);
result
->
graph
=
HeteroGraphRef
(
HeteroGraphPtr
(
new
HeteroGraph
(
gptr
->
meta_graph
(),
{
gptr
}))
);
result
->
induced_srctype
=
aten
::
VecToIdArray
(
induced_srctype
).
CopyTo
(
Context
());
result
->
induced_srctype_set
=
aten
::
VecToIdArray
(
srctype_set
).
CopyTo
(
Context
());
result
->
induced_srcid
=
aten
::
VecToIdArray
(
induced_srcid
).
CopyTo
(
Context
());
...
...
tests/compute/test_heterograph.py
View file @
ac570c1d
...
...
@@ -794,6 +794,8 @@ def test_flatten(idtype):
assert
fg
.
etypes
==
[
'plays+wishes'
]
assert
fg
.
idtype
==
g
.
idtype
assert
fg
.
device
==
g
.
device
etype
=
fg
.
etypes
[
0
]
assert
fg
[
etype
]
is
not
None
# Issue #2166
assert
F
.
array_equal
(
fg
.
nodes
[
'user'
].
data
[
'h'
],
F
.
ones
((
3
,
5
)))
assert
F
.
array_equal
(
fg
.
nodes
[
'game'
].
data
[
'i'
],
F
.
ones
((
2
,
5
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment