Unverified Commit ac570c1d authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix flatten not wrapping unit graph (#2170)

* fix flatten not wrapping unit graph

* fix doc
parent 2c04ecb5
......@@ -1875,7 +1875,7 @@ class DGLHeteroGraph(object):
def __getitem__(self, key):
"""Return the relation slice of this graph.
A relation slice is accessed with ``self[srctype, etype, dsttype]``, where
You can get a relation slice with ``self[srctype, etype, dsttype]``, where
``srctype``, ``etype``, and ``dsttype`` can be either a string or a full
slice (``:``) representing wildcard (i.e. any source/edge/destination type).
......@@ -1893,8 +1893,63 @@ class DGLHeteroGraph(object):
new source/destination node type would have the concatenation determined by
:func:`dgl.combine_names() <dgl.combine_names>` called on original source/destination
types as its name. The source/destination node would be formed by concatenating the
common features of the original source/destination types, therefore they are not
common features of the original source/destination types. Therefore they are not
shared with the original graph. Edge type is similar.
Parameters
----------
key : str or tuple
Either a string representing the edge type name, or a tuple in the form of
``(srctype, etype, dsttype)`` where ``srctype``, ``etype``, ``dsttype`` can be either
strings representing type names or a full slice object (`:`).
Returns
-------
DGLGraph
The relation slice.
Notes
-----
This function returns a new graph. Changing the content of this graph does not reflect
onto the original graph.
If the graph combines multiple node types or edge types together, it will have the
mapping of node/edge types and IDs from the new graph to the original graph.
The mappings have the name ``dgl.NTYPE``, ``dgl.NID``, ``dgl.ETYPE`` and ``dgl.EID``,
similar to the function :func:`dgl.to_homogenenous`.
Examples
--------
>>> g = dgl.heterograph({
... ('A1', 'AB1', 'B'): ([0, 1, 2], [1, 2, 3]),
... ('A1', 'AB2', 'B'): ([1, 2, 3], [3, 4, 5]),
... ('A2', 'AB2', 'B'): ([1, 3, 5], [2, 4, 6])})
>>> new_g = g['A1', :, 'B'] # combines all edge types between A1 and B
>>> new_g
Graph(num_nodes={'A1': 4, 'B': 7},
num_edges={('A1', 'AB1+AB2', 'B'): 6},
metagraph=[('A1', 'B', 'AB1+AB2')])
>>> new_g.edges()
(tensor([0, 1, 2, 1, 2, 3]), tensor([1, 2, 3, 3, 4, 5]))
>>> new_g2 = g[:, 'AB2', 'B'] # combines all node types that are source of AB2
>>> new_g2
Graph(num_nodes={'A1+A2': 10, 'B': 7},
num_edges={('A1+A2', 'AB2+AB2', 'B'): 6},
metagraph=[('A1+A2', 'B', 'AB2+AB2')])
>>> new_g2.edges()
(tensor([1, 2, 3, 5, 7, 9]), tensor([3, 4, 5, 2, 4, 6]))
If a combination of multiple node types and edge types occur, one can find
the mapping to the original node type and IDs like the following:
>>> new_g1.edges['AB1+AB2'].data[dgl.EID]
tensor([0, 1, 2, 0, 1, 2])
>>> new_g1.edges['AB1+AB2'].data[dgl.ETYPE]
tensor([0, 0, 0, 1, 1, 1])
>>> new_g2.nodes['A1+A2'].data[dgl.NID]
tensor([0, 1, 2, 3, 0, 1, 2, 3, 4, 5])
>>> new_g2.nodes['A1+A2'].data[dgl.NTYPE]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
"""
err_msg = "Invalid slice syntax. Use G['etype'] or G['srctype', 'etype', 'dsttype'] " +\
"to get view of one relation type. Use : to slice multiple types (e.g. " +\
......
......@@ -484,7 +484,7 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
CHECK_EQ(gptr->NumBits(), NumBits());
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(gptr);
result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
......
......@@ -794,6 +794,8 @@ def test_flatten(idtype):
assert fg.etypes == ['plays+wishes']
assert fg.idtype == g.idtype
assert fg.device == g.device
etype = fg.etypes[0]
assert fg[etype] is not None # Issue #2166
assert F.array_equal(fg.nodes['user'].data['h'], F.ones((3, 5)))
assert F.array_equal(fg.nodes['game'].data['i'], F.ones((2, 5)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment