Unverified Commit b742c559 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature]heterograph.edata and heterograph.ndata support multiple ntype and etype (#1673)



* g.edata and g.ndata support multiple ntype and etype

* Add test case

* support g.srcdata and g.dstdata

* Fix test

* lint

* Fix

* Fix some comments

* lint
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 8d16d4bb
......@@ -76,17 +76,17 @@ class DGLHeteroGraph(object):
One can construct the graph as follows:
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
Or equivalently
>>> g = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... ('user', 'follows', 'user'): ([0, 1], [1, 2]),
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... })
:func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
......@@ -126,8 +126,8 @@ class DGLHeteroGraph(object):
For example, suppose a graph that has two types of relation "user-watches-movie"
and "user-watches-TV" as follows:
>>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV')
>>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs
To distinguish between the two "watches" edge type, one must specify a full triplet:
......@@ -387,8 +387,8 @@ class DGLHeteroGraph(object):
Examples
--------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.ntypes
['user', 'game']
......@@ -406,8 +406,8 @@ class DGLHeteroGraph(object):
Examples
--------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.etypes
['follows', 'plays']
......@@ -427,8 +427,8 @@ class DGLHeteroGraph(object):
Examples
--------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.canonical_etypes
[('user', 'follows', 'user'), ('user', 'plays', 'game')]
......@@ -469,8 +469,8 @@ class DGLHeteroGraph(object):
Examples
--------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> meta_g = g.metagraph
......@@ -512,9 +512,9 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> g3 = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'follows', 'game')
>>> g1 = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g3 = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'follows', 'game')
>>> g = dgl.hetero_from_relations([g1, g2, g3])
Get canonical edge types.
......@@ -662,7 +662,7 @@ class DGLHeteroGraph(object):
To set features of all users
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.zeros(3, 5)
See Also
......@@ -682,7 +682,7 @@ class DGLHeteroGraph(object):
To set features of all users
>>> g = dgl.biparite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5)
See Also
......@@ -702,7 +702,7 @@ class DGLHeteroGraph(object):
To set features of all games
>>> g = dgl.biparite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5)
See Also
......@@ -715,7 +715,12 @@ class DGLHeteroGraph(object):
def ndata(self):
"""Return the data view of all the nodes.
**Only works if the graph has one node type.**
If the graph has only one node type, ``g.ndata['feat']`` gives
the node feature data under name ``'feat'``.
If the graph has multiple node types, then ``g.ndata['feat']``
returns a dictionary where the key is the node type and the
value is the node feature tensor. If the node type does not
have feature `'feat'`, it is not included in the dictionary.
Examples
--------
......@@ -724,27 +729,60 @@ class DGLHeteroGraph(object):
To set features of all nodes in a heterogeneous graph
with only one node type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.ndata['h'] = torch.zeros(3, 5)
To set features of all nodes in a heterogeneous graph
with multiple node types:
>>> g = dgl.heterograph({('user', 'like', 'movie') : ([0, 1, 1], [1, 2, 0])})
>>> g.ndata['h'] = {'user': torch.zeros(2, 5),
... 'movie': torch.zeros(3, 5)}
>>> g.ndata['h']
... {'user': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'movie': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of nodes in a heterogeneous graph
with multiple node types:
>>> g = dgl.heterograph({('user', 'like', 'movie') : ([0, 1, 1], [1, 2, 0])})
>>> g.ndata['h'] = {'user': torch.zeros(2, 5)}
>>> g.ndata['h']
... {'user': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no node type contains 'h'
>>> g.ndata.pop('h')
>>> g.ndata['h']
... {}
See Also
--------
nodes
"""
ntid = self.get_ntype_id(None)
ntype = self.ntypes[0]
return HeteroNodeDataView(self, ntype, ntid, ALL)
if len(self.ntypes) == 1:
ntid = self.get_ntype_id(None)
ntype = self.ntypes[0]
return HeteroNodeDataView(self, ntype, ntid, ALL)
else:
ntids = [self.get_ntype_id(ntype) for ntype in self.ntypes]
ntypes = self.ntypes
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property
def srcdata(self):
"""Return the data view of all nodes in the SRC category.
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
If the source nodes have only one node type, ``g.srcdata['feat']``
gives the node feature data under name ``'feat'``.
If the source nodes have multiple node types, then
``g.srcdata['feat']`` returns a dictionary where the key is
the source node type and the value is the node feature
tensor. If the source node type does not have feature
`'feat'`, it is not included in the dictionary.
Examples
--------
......@@ -752,7 +790,7 @@ class DGLHeteroGraph(object):
To set features of all source nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.srcdata['h'] = torch.zeros(2, 5)
This is equivalent to
......@@ -762,13 +800,48 @@ class DGLHeteroGraph(object):
Also work on more complex uni-bipartite graph
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'), [(0, 1), (1, 2)],
... ('user', 'reads', 'book'), [(0, 1), (1, 0)],
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.srcdata['h'] = torch.zeros(2, 5)
To set features of all source nodes in a uni-bipartite graph
with multiple source node types:
>>> g = dgl.heterograph({
... ('game', 'liked-by', 'user') : ([1, 2], [0, 1]),
... ('book', 'liked-by', 'user') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.srcdata['h'] = {'game' : torch.zeros(3, 5),
... 'book' : torch.zeros(2, 5)}
>>> g.srcdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'book': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of source nodes in a uni-bipartite graph
with multiple source node types:
>>> g = dgl.heterograph({
... ('game', 'liked-by', 'user') : ([1, 2], [0, 1]),
... ('book', 'liked-by', 'user') : ([0, 1], [1, 0]),
... })
>>> g.srcdata['h'] = {'game' : torch.zeros(3, 5)}
>>> g.srcdata['h']
>>> {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no source node type contains 'h'
>>> g.srcdata.pop('h')
>>> g.srcdata['h']
... {}
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
......@@ -777,24 +850,27 @@ class DGLHeteroGraph(object):
--------
nodes
"""
err_msg = (
'srcdata is only allowed when there is only one %s type.' %
('SRC' if self.is_unibipartite else 'node'))
assert len(self.srctypes) == 1, err_msg
ntype = self.srctypes[0]
ntid = self.get_ntype_id_from_src(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
if len(self.srctypes) == 1:
ntype = self.srctypes[0]
ntid = self.get_ntype_id_from_src(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
else:
ntypes = self.srctypes
ntids = [self.get_ntype_id_from_src(ntype) for ntype in ntypes]
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property
def dstdata(self):
"""Return the data view of all destination nodes.
Only works if the graph is either
* Uni-bipartite and has one node type in the SRC category.
* Non-uni-bipartite and has only one node type (in this case identical to
:any:`DGLHeteroGraph.ndata`)
If the destination nodes have only one node type,
``g.dstdata['feat']`` gives the node feature data under name
``'feat'``.
If the destination nodes have multiple node types, then
``g.dstdata['feat']`` returns a dictionary where the key is
the destination node type and the value is the node feature
tensor. If the destination node type does not have feature
`'feat'`, it is not included in the dictionary.
Examples
--------
......@@ -802,7 +878,7 @@ class DGLHeteroGraph(object):
To set features of all source nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.dstdata['h'] = torch.zeros(3, 5)
This is equivalent to
......@@ -812,13 +888,47 @@ class DGLHeteroGraph(object):
Also work on more complex uni-bipartite graph
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'), [(0, 1), (1, 2)],
... ('store', 'sells', 'game'), [(0, 1), (1, 0)],
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('store', 'sells', 'game') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.dstdata['h'] = torch.zeros(3, 5)
To set features of all destination nodes in a uni-bipartite graph
with multiple destination node types::
>>> g = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.dstdata['h'] = {'game' : torch.zeros(3, 5),
... 'book' : torch.zeros(2, 5)}
>>> g.dstdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'book': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of destination nodes in a uni-bipartite graph
with multiple destination node types:
>>> g = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... })
>>> g.dstdata['h'] = {'game' : torch.zeros(3, 5)}
>>> g.dstdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no destination node type contains 'h'
>>> g.dstdata.pop('h')
>>> g.dstdata['h']
... {}
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
......@@ -827,13 +937,14 @@ class DGLHeteroGraph(object):
--------
nodes
"""
err_msg = (
'dstdata is only allowed when there is only one %s type.' %
('DST' if self.is_unibipartite else 'node'))
assert len(self.dsttypes) == 1, err_msg
ntype = self.dsttypes[0]
ntid = self.get_ntype_id_from_dst(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
if len(self.dsttypes) == 1:
ntype = self.dsttypes[0]
ntid = self.get_ntype_id_from_dst(ntype)
return HeteroNodeDataView(self, ntype, ntid, ALL)
else:
ntypes = self.dsttypes
ntids = [self.get_ntype_id_from_dst(ntype) for ntype in ntypes]
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property
def edges(self):
......@@ -846,7 +957,7 @@ class DGLHeteroGraph(object):
To set features of all "play" relationships:
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edges['plays'].data['h'] = torch.zeros(3, 4)
See Also
......@@ -859,7 +970,16 @@ class DGLHeteroGraph(object):
def edata(self):
"""Return the data view of all the edges.
**Only works if the graph has one edge type.**
If the graph has only one edge type, ``g.edata['feat']`` gives the
edge feature data under name ``'feat'``.
If the graph has multiple edge types, then ``g.edata['feat']``
returns a dictionary where the key is the edge type and the value
is the edge feature tensor. If the edge type does not have feature
``'feat'``, it is not included in the dictionary.
Note: When the graph has multiple edge type, The key used in
``g.edata['feat']`` should be the canonical_etypes, i.e.
(h_ntype, r_type, t_ntype).
Examples
--------
......@@ -868,14 +988,47 @@ class DGLHeteroGraph(object):
To set features of all edges in a heterogeneous graph
with only one edge type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.edata['h'] = torch.zeros(2, 5)
To set features of all edges in a heterogeneous graph
with multiple edge types:
>>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> g = dgl.hetero_from_relations([g0, g1])
>>> g.edata['h'] = {('user', 'watches', 'movie') : torch.zeros(3, 5),
('user', 'watches', 'TV') : torch.zeros(2, 5)}
>>> g.edata['h']
... {('user', 'watches', 'movie'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... ('user', 'watches', 'TV'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of edges in a heterogeneous graph
with multiple edge types:
>>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> g = dgl.hetero_from_relations([g0, g1])
>>> g.edata['h'] = {('user', 'watches', 'movie') : torch.zeros(3, 5)}
>>> g.edata['h']
... {('user', 'watches', 'movie'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no edge type contains 'h'
>>> g.edata.pop('h')
>>> g.edata['h']
... {}
See Also
--------
edges
"""
return HeteroEdgeDataView(self, None, ALL)
if len(self.canonical_etypes) == 1:
return HeteroEdgeDataView(self, None, ALL)
else:
return HeteroEdgeDataView(self, self.canonical_etypes, ALL)
def _find_etypes(self, key):
etypes = [
......@@ -999,7 +1152,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.number_of_nodes('user')
3
>>> g.number_of_nodes()
......@@ -1025,7 +1178,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.number_of_src_nodes('user')
2
>>> g.number_of_src_nodes()
......@@ -1053,7 +1206,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.number_of_dst_nodes('game')
3
>>> g.number_of_dst_nodes()
......@@ -1080,8 +1233,9 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.number_of_edges(('user', 'follows', 'user'))
2
>>> g.number_of_edges('follows')
2
>>> g.number_of_edges()
......@@ -1290,8 +1444,8 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([plays_g, devs_g])
>>> g.predecessors(0, 'plays')
tensor([0, 1])
......@@ -1328,8 +1482,8 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> g.successors(0, 'plays')
tensor([0])
......@@ -1381,8 +1535,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for edge id.
......@@ -1456,8 +1610,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for edge ids.
......@@ -1515,7 +1669,7 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.find_edges([0, 2], ('user', 'plays', 'game'))
(tensor([0, 1]), tensor([0, 2]))
>>> g.find_edges([0, 2])
......@@ -1572,7 +1726,7 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 1, 2]), 'user', 'plays', 'game')
>>> g.in_edges([0, 2], form='eid')
tensor([0, 2])
>>> g.in_edges([0, 2], form='all')
......@@ -1624,7 +1778,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 1, 2]), 'user', 'plays', 'game')
>>> g.out_edges([0, 1], form='eid')
tensor([0, 1, 2])
>>> g.out_edges([0, 1], form='all')
......@@ -1680,7 +1834,7 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g = dgl.bipartite([(1, 1), (0, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([1, 0, 1], [1, 0, 2]), 'user', 'plays', 'game')
>>> g.all_edges(form='eid', order='srcdst')
tensor([1, 0, 2])
>>> g.all_edges(form='all', order='srcdst')
......@@ -1719,8 +1873,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree.
......@@ -1760,8 +1914,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree.
......@@ -1805,8 +1959,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree.
......@@ -1846,8 +2000,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree.
......@@ -1925,8 +2079,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
......@@ -2014,8 +2168,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
......@@ -2091,8 +2245,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
......@@ -2172,8 +2326,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
......@@ -2256,8 +2410,8 @@ class DGLHeteroGraph(object):
Instantiate a heterogeneous graph.
>>> follows_g = dgl.graph([(0, 0), (1, 1)], 'user', 'follows')
>>> devs_g = dgl.bipartite([(0, 0), (1, 2)], 'developer', 'develops', 'game')
>>> follows_g = dgl.graph(([0, 1], [0, 1]), 'user', 'follows')
>>> devs_g = dgl.bipartite(([0, 1], [0, 2]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, devs_g])
Get a backend dependent sparse tensor. Here we use PyTorch for example.
......@@ -2337,7 +2491,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [0, 2]), 'user', 'follows')
>>> g.incidence_matrix('in')
tensor(indices=tensor([[0, 2],
[0, 1]]),
......@@ -2386,7 +2540,7 @@ class DGLHeteroGraph(object):
--------
The following uses PyTorch backend.
>>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [0, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)}
......@@ -2418,7 +2572,7 @@ class DGLHeteroGraph(object):
--------
The following uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
>>> g.edge_attr_schemes(('user', 'plays', 'game'))
{'h': Scheme(shape=(4,), dtype=torch.float32)}
......@@ -2716,7 +2870,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.ones(3, 5)
>>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
>>> g.nodes['user'].data['h']
......@@ -2762,7 +2916,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
>>> g.edges[('user', 'plays', 'game')].data['h']
......@@ -2824,7 +2978,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.graph([(0, 1), (0, 2), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 0, 1], [1, 2, 2]), 'user', 'follows')
>>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
>>> def softmax_feat(edges):
>>> return {'norm_feat': th.softmax(edges.data['feat'], dim=1)}
......@@ -2908,7 +3062,7 @@ class DGLHeteroGraph(object):
>>> import dgl.function as fn
>>> import torch
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Different ways for sending messages.
......@@ -3082,8 +3236,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
>>> g1 = dgl.graph(([0], [1]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
......@@ -3187,8 +3341,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
Trigger "send" and "receive" separately.
......@@ -3280,8 +3434,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1)], 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
>>> g1 = dgl.graph(([0], [1]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
Trigger send and recv separately.
......@@ -3411,8 +3565,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (2, 1)], 'user', 'plays', 'game')
>>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite(([0, 2], [0, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
......@@ -3482,8 +3636,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g1 = dgl.graph([(1, 1), (1, 0)], 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
>>> g1 = dgl.graph(([1, 1], [1, 0]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
Pull.
......@@ -3580,7 +3734,7 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g = dgl.graph([(0, 1), (0, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 0], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Push.
......@@ -3648,7 +3802,7 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g = dgl.graph([(0, 1), (1, 2), (2, 2)], 'user', 'follows')
>>> g = dgl.graph(([0, 1, 2], [1, 2, 2]), 'user', 'follows')
Update all.
......@@ -3712,8 +3866,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1), (1, 1)], 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user')
>>> g1 = dgl.graph(([0, 1], [1, 1]), 'user', 'follows')
>>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
......@@ -4013,7 +4167,7 @@ class DGLHeteroGraph(object):
>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows')
>>> g = dgl.graph(([0, 0, 1, 2], [0, 1, 2, 3]), 'user', 'follows')
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
tensor([1, 2])
......@@ -4055,7 +4209,7 @@ class DGLHeteroGraph(object):
--------
The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> print(g.device)
device(type='cpu')
>>> g = g.to('cuda:0')
......@@ -4088,7 +4242,7 @@ class DGLHeteroGraph(object):
The following example uses PyTorch backend.
>>> import torch
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
>>> g1 = g.to(torch.device('cuda:0'))
......@@ -4146,7 +4300,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h']
>>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
>>> newh = foo(g) # get tensor of all ones
>>> print(g.edata['h']) # still get tensor of all zeros
......@@ -4160,7 +4314,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h']
>>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> h = foo(g)
>>> print('h' in g.edata)
False
......@@ -4198,7 +4352,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h']
>>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
>>> newh = foo(g) # get tensor of all ones
>>> print(g.edata['h']) # still get tensor of all zeros
......@@ -4212,7 +4366,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h']
>>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game')
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> h = foo(g)
>>> print('h' in g.edata)
False
......@@ -4251,15 +4405,15 @@ class DGLHeteroGraph(object):
--------
For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr')
>>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows', restrict_format='csr')
>>> g.format_in_use()
['csr']
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any')
>>> g.format_in_use('develops')
['coo']
......@@ -4306,8 +4460,8 @@ class DGLHeteroGraph(object):
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any')
>>> g.restrict_format('develops')
'any'
......@@ -4354,8 +4508,8 @@ class DGLHeteroGraph(object):
For a graph with multiple types.
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any')
>>> g.format_in_use('develops')
['coo']
......@@ -4416,8 +4570,8 @@ class DGLHeteroGraph(object):
For a graph with multiple edge types:
>>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)],
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)],
... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='coo')
>>> g.restrict_format('develops')
'coo'
......@@ -4447,7 +4601,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game',
>>> index_dtype='int32')
>>> g_long = g.long() # Convert g to int64 indexed, not changing the original `g`
......@@ -4472,7 +4626,7 @@ class DGLHeteroGraph(object):
Examples
--------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game',
>>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game',
>>> index_dtype='int64')
>>> g_int = g.int() # Convert g to int32 indexed, not changing the original `g`
......
......@@ -291,24 +291,65 @@ class HeteroNodeDataView(MutableMapping):
self._nodes = nodes
def __getitem__(self, key):
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(key, None)
if value is not None:
ret[ntype] = value
return ret
else:
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val):
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
if isinstance(self._ntype, list):
assert isinstance(val, dict), \
'Current HeteroNodeDataView has multiple node types, ' \
'please passing the node type and the corresponding data through a dict.'
for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroNodeDataView has only one node type. ' \
'please pass a tensor directly'
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key):
self._graph._pop_n_repr(self._ntid, key)
if isinstance(self._ntype, list):
for ntid in self._ntid:
if self._graph._get_n_repr(ntid, ALL).get(key, None) is None:
continue
self._graph._pop_n_repr(ntid, key)
else:
self._graph._pop_n_repr(self._ntid, key)
def __len__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not support len().'
return len(self._graph._node_frames[self._ntid])
def __iter__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not be iterated.'
return iter(self._graph._node_frames[self._ntid])
def __repr__(self):
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._ntid]})
if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
data = self._graph._get_n_repr(self._ntid[i], self._nodes)
value = {key : data[key]
for key in self._graph._node_frames[self._ntid[i]]}
ret[ntype] = value
return repr(ret)
else:
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
......@@ -345,31 +386,74 @@ class HeteroEdgeView(object):
return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.ndata[etype] is called."""
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges']
def __init__(self, graph, etype, edges):
self._graph = graph
self._etype = etype
self._etid = self._graph.get_etype_id(etype)
self._etid = [self._graph.get_etype_id(t) for t in etype] \
if isinstance(etype, list) \
else self._graph.get_etype_id(etype)
self._edges = edges
def __getitem__(self, key):
return self._graph._get_e_repr(self._etid, self._edges)[key]
if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(key, None)
if value is not None:
ret[etype] = value
return ret
else:
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val):
self._graph._set_e_repr(self._etid, self._edges, {key : val})
if isinstance(self._etype, list):
assert isinstance(val, dict), \
'Current HeteroEdgeDataView has multiple edge types, ' \
'please pass the edge type and the corresponding data through a dict.'
for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroEdgeDataView has only one edge type. ' \
'please pass a tensor directly'
self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key):
self._graph._pop_e_repr(self._etid, key)
if isinstance(self._etype, list):
for etid in self._etid:
if self._graph._get_e_repr(etid, ALL).get(key, None) is None:
continue
self._graph._pop_e_repr(etid, key)
else:
self._graph._pop_e_repr(self._etid, key)
def __len__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not support len().'
return len(self._graph._edge_frames[self._etid])
def __iter__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not be iterated.'
return iter(self._graph._edge_frames[self._etid])
def __repr__(self):
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._etid]})
if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
data = self._graph._get_e_repr(self._etid[i], self._edges)
value = {key : data[key]
for key in self._graph._edge_frames[self._etid[i]]}
ret[etype] = value
return repr(ret)
else:
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._etid]})
......@@ -494,6 +494,31 @@ def test_inc(index_dtype):
@parametrize_dtype
def test_view(index_dtype):
# test single node type
g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
f1 = F.randn((3, 6))
g.ndata['h'] = f1
f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2)
fail = False
try:
g.ndata['h'] = {'user' : f1}
except Exception:
fail = True
assert fail
# test single edge type
f3 = F.randn((2, 4))
g.edata['h'] = f3
f4 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
fail = False
try:
g.edata['h'] = {'follows' : f3}
except Exception:
fail = True
assert fail
# test data view
g = create_test_heterograph(index_dtype)
......@@ -502,6 +527,31 @@ def test_view(index_dtype):
f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.nodes('user')), F.arange(0, 3))
g.nodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.ndata['h'] = f1
except Exception:
fail = True
assert fail
g.ndata['h'] = {'user' : f1,
'game' : f2}
f3 = g.nodes['user'].data['h']
f4 = g.nodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.ndata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.ndata)
g.ndata.pop('h')
# test repr
print(g.ndata)
f3 = F.randn((2, 4))
g.edges['user', 'follows', 'user'].data['h'] = f3
......@@ -510,6 +560,87 @@ def test_view(index_dtype):
assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5)
assert F.array_equal(F.tensor(g.edges(etype='follows', form='eid')), F.arange(0, 2))
g.edges['follows'].data.pop('h')
f3 = F.randn((2, 4))
fail = False
try:
g.edata['h'] = f3
except Exception:
fail = True
assert fail
g.edata['h'] = {('user', 'follows', 'user') : f3}
f4 = g.edges['user', 'follows', 'user'].data['h']
f5 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5)
data = g.edata['h']
assert F.array_equal(f3, data[('user', 'follows', 'user')])
# test repr
print(g.edata)
g.edata.pop('h')
# test repr
print(g.edata)
# test srcdata
f1 = F.randn((3, 6))
g.srcnodes['user'].data['h'] = f1 # ok
f2 = g.srcnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.srcnodes('user')), F.arange(0, 3))
g.srcnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.srcdata['h'] = f1
except Exception:
fail = True
assert fail
g.srcdata['h'] = {'user' : f1,
'developer' : f2}
f3 = g.srcnodes['user'].data['h']
f4 = g.srcnodes['developer'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.srcdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['developer'])
# test repr
print(g.srcdata)
g.srcdata.pop('h')
# test dstdata
f1 = F.randn((3, 6))
g.dstnodes['user'].data['h'] = f1 # ok
f2 = g.dstnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.dstnodes('user')), F.arange(0, 3))
g.dstnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.dstdata['h'] = f1
except Exception:
fail = True
assert fail
g.dstdata['h'] = {'user' : f1,
'game' : f2}
f3 = g.dstnodes['user'].data['h']
f4 = g.dstnodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.dstdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.dstdata)
g.dstdata.pop('h')
@parametrize_dtype
def test_view1(index_dtype):
......@@ -639,21 +770,14 @@ def test_view1(index_dtype):
assert F.array_equal(f3, f4)
assert F.array_equal(F.tensor(g.edges(form='eid')), F.arange(0, 2))
# test fail case
# fail due to multiple types
fail = False
try:
HG.ndata['h']
except dgl.DGLError:
fail = True
assert fail
fail = False
try:
HG.edata['h']
except dgl.DGLError:
fail = True
assert fail
# multiple types
ndata = HG.ndata['h']
assert isinstance(ndata, dict)
assert F.array_equal(ndata['user'], f2)
edata = HG.edata['h']
assert isinstance(edata, dict)
assert F.array_equal(edata[('user', 'follows', 'user')], f4)
@parametrize_dtype
def test_flatten(index_dtype):
......@@ -1740,7 +1864,7 @@ if __name__ == '__main__':
# test_hypersparse()
# test_adj("int32")
# test_inc()
# test_view()
# test_view("int32")
# test_view1("int32")
# test_flatten()
# test_convert_bound()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment