Unverified Commit b742c559 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature]heterograph.edata and heterograph.ndata support multiple ntype and etype (#1673)



* g.edata and g.ndata support multiple ntype and etype

* Add test case

* support g.srcdata and g.dstdata

* Fix test

* lint

* Fix

* Fix some comments

* lint
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 8d16d4bb
...@@ -76,17 +76,17 @@ class DGLHeteroGraph(object): ...@@ -76,17 +76,17 @@ class DGLHeteroGraph(object):
One can construct the graph as follows: One can construct the graph as follows:
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game') >>> devs_g = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
Or equivalently Or equivalently
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)], ... ('user', 'follows', 'user'): ([0, 1], [1, 2]),
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)], ... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }) ... })
:func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of :func:`dgl.graph` and :func:`dgl.bipartite` can create a graph from a variety of
...@@ -126,8 +126,8 @@ class DGLHeteroGraph(object): ...@@ -126,8 +126,8 @@ class DGLHeteroGraph(object):
For example, suppose a graph that has two types of relation "user-watches-movie" For example, suppose a graph that has two types of relation "user-watches-movie"
and "user-watches-TV" as follows: and "user-watches-TV" as follows:
>>> g0 = dgl.bipartite([(0, 1), (1, 0), (1, 1)], 'user', 'watches', 'movie') >>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite([(0, 0), (1, 1)], 'user', 'watches', 'TV') >>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs >>> GG = dgl.hetero_from_relations([g0, g1]) # Merge the two graphs
To distinguish between the two "watches" edge type, one must specify a full triplet: To distinguish between the two "watches" edge type, one must specify a full triplet:
...@@ -387,8 +387,8 @@ class DGLHeteroGraph(object): ...@@ -387,8 +387,8 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.ntypes >>> g.ntypes
['user', 'game'] ['user', 'game']
...@@ -406,8 +406,8 @@ class DGLHeteroGraph(object): ...@@ -406,8 +406,8 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.etypes >>> g.etypes
['follows', 'plays'] ['follows', 'plays']
...@@ -427,8 +427,8 @@ class DGLHeteroGraph(object): ...@@ -427,8 +427,8 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.canonical_etypes >>> g.canonical_etypes
[('user', 'follows', 'user'), ('user', 'plays', 'game')] [('user', 'follows', 'user'), ('user', 'plays', 'game')]
...@@ -469,8 +469,8 @@ class DGLHeteroGraph(object): ...@@ -469,8 +469,8 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> meta_g = g.metagraph >>> meta_g = g.metagraph
...@@ -512,9 +512,9 @@ class DGLHeteroGraph(object): ...@@ -512,9 +512,9 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g1 = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> g2 = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g3 = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'follows', 'game') >>> g3 = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'follows', 'game')
>>> g = dgl.hetero_from_relations([g1, g2, g3]) >>> g = dgl.hetero_from_relations([g1, g2, g3])
Get canonical edge types. Get canonical edge types.
...@@ -662,7 +662,7 @@ class DGLHeteroGraph(object): ...@@ -662,7 +662,7 @@ class DGLHeteroGraph(object):
To set features of all users To set features of all users
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.zeros(3, 5) >>> g.nodes['user'].data['h'] = torch.zeros(3, 5)
See Also See Also
...@@ -682,7 +682,7 @@ class DGLHeteroGraph(object): ...@@ -682,7 +682,7 @@ class DGLHeteroGraph(object):
To set features of all users To set features of all users
>>> g = dgl.biparite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5) >>> g.srcnodes['user'].data['h'] = torch.zeros(2, 5)
See Also See Also
...@@ -702,7 +702,7 @@ class DGLHeteroGraph(object): ...@@ -702,7 +702,7 @@ class DGLHeteroGraph(object):
To set features of all games To set features of all games
>>> g = dgl.biparite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5) >>> g.dstnodes['game'].data['h'] = torch.zeros(3, 5)
See Also See Also
...@@ -715,7 +715,12 @@ class DGLHeteroGraph(object): ...@@ -715,7 +715,12 @@ class DGLHeteroGraph(object):
def ndata(self): def ndata(self):
"""Return the data view of all the nodes. """Return the data view of all the nodes.
**Only works if the graph has one node type.** If the graph has only one node type, ``g.ndata['feat']`` gives
the node feature data under name ``'feat'``.
If the graph has multiple node types, then ``g.ndata['feat']``
returns a dictionary where the key is the node type and the
value is the node feature tensor. If the node type does not
have feature `'feat'`, it is not included in the dictionary.
Examples Examples
-------- --------
...@@ -724,27 +729,60 @@ class DGLHeteroGraph(object): ...@@ -724,27 +729,60 @@ class DGLHeteroGraph(object):
To set features of all nodes in a heterogeneous graph To set features of all nodes in a heterogeneous graph
with only one node type: with only one node type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.ndata['h'] = torch.zeros(3, 5) >>> g.ndata['h'] = torch.zeros(3, 5)
To set features of all nodes in a heterogeneous graph
with multiple node types:
>>> g = dgl.heterograph({('user', 'like', 'movie') : ([0, 1, 1], [1, 2, 0])})
>>> g.ndata['h'] = {'user': torch.zeros(2, 5),
... 'movie': torch.zeros(3, 5)}
>>> g.ndata['h']
... {'user': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'movie': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of nodes in a heterogeneous graph
with multiple node types:
>>> g = dgl.heterograph({('user', 'like', 'movie') : ([0, 1, 1], [1, 2, 0])})
>>> g.ndata['h'] = {'user': torch.zeros(2, 5)}
>>> g.ndata['h']
... {'user': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no node type contains 'h'
>>> g.ndata.pop('h')
>>> g.ndata['h']
... {}
See Also See Also
-------- --------
nodes nodes
""" """
ntid = self.get_ntype_id(None) if len(self.ntypes) == 1:
ntype = self.ntypes[0] ntid = self.get_ntype_id(None)
return HeteroNodeDataView(self, ntype, ntid, ALL) ntype = self.ntypes[0]
return HeteroNodeDataView(self, ntype, ntid, ALL)
else:
ntids = [self.get_ntype_id(ntype) for ntype in self.ntypes]
ntypes = self.ntypes
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property @property
def srcdata(self): def srcdata(self):
"""Return the data view of all nodes in the SRC category. """Return the data view of all nodes in the SRC category.
Only works if the graph is either If the source nodes have only one node type, ``g.srcdata['feat']``
gives the node feature data under name ``'feat'``.
* Uni-bipartite and has one node type in the SRC category. If the source nodes have multiple node types, then
``g.srcdata['feat']`` returns a dictionary where the key is
* Non-uni-bipartite and has only one node type (in this case identical to the source node type and the value is the node feature
:any:`DGLHeteroGraph.ndata`) tensor. If the source node type does not have feature
`'feat'`, it is not included in the dictionary.
Examples Examples
-------- --------
...@@ -752,7 +790,7 @@ class DGLHeteroGraph(object): ...@@ -752,7 +790,7 @@ class DGLHeteroGraph(object):
To set features of all source nodes in a graph with only one edge type: To set features of all source nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.srcdata['h'] = torch.zeros(2, 5) >>> g.srcdata['h'] = torch.zeros(2, 5)
This is equivalent to This is equivalent to
...@@ -762,13 +800,48 @@ class DGLHeteroGraph(object): ...@@ -762,13 +800,48 @@ class DGLHeteroGraph(object):
Also work on more complex uni-bipartite graph Also work on more complex uni-bipartite graph
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'), [(0, 1), (1, 2)], ... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book'), [(0, 1), (1, 0)], ... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... }) ... })
>>> print(g.is_unibipartite) >>> print(g.is_unibipartite)
True True
>>> g.srcdata['h'] = torch.zeros(2, 5) >>> g.srcdata['h'] = torch.zeros(2, 5)
To set features of all source nodes in a uni-bipartite graph
with multiple source node types:
>>> g = dgl.heterograph({
... ('game', 'liked-by', 'user') : ([1, 2], [0, 1]),
... ('book', 'liked-by', 'user') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.srcdata['h'] = {'game' : torch.zeros(3, 5),
... 'book' : torch.zeros(2, 5)}
>>> g.srcdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'book': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of source nodes in a uni-bipartite graph
with multiple source node types:
>>> g = dgl.heterograph({
... ('game', 'liked-by', 'user') : ([1, 2], [0, 1]),
... ('book', 'liked-by', 'user') : ([0, 1], [1, 0]),
... })
>>> g.srcdata['h'] = {'game' : torch.zeros(3, 5)}
>>> g.srcdata['h']
>>> {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no source node type contains 'h'
>>> g.srcdata.pop('h')
>>> g.srcdata['h']
... {}
Notes Notes
----- -----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous. This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
...@@ -777,24 +850,27 @@ class DGLHeteroGraph(object): ...@@ -777,24 +850,27 @@ class DGLHeteroGraph(object):
-------- --------
nodes nodes
""" """
err_msg = ( if len(self.srctypes) == 1:
'srcdata is only allowed when there is only one %s type.' % ntype = self.srctypes[0]
('SRC' if self.is_unibipartite else 'node')) ntid = self.get_ntype_id_from_src(ntype)
assert len(self.srctypes) == 1, err_msg return HeteroNodeDataView(self, ntype, ntid, ALL)
ntype = self.srctypes[0] else:
ntid = self.get_ntype_id_from_src(ntype) ntypes = self.srctypes
return HeteroNodeDataView(self, ntype, ntid, ALL) ntids = [self.get_ntype_id_from_src(ntype) for ntype in ntypes]
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property @property
def dstdata(self): def dstdata(self):
"""Return the data view of all destination nodes. """Return the data view of all destination nodes.
Only works if the graph is either If the destination nodes have only one node type,
``g.dstdata['feat']`` gives the node feature data under name
* Uni-bipartite and has one node type in the SRC category. ``'feat'``.
If the destination nodes have multiple node types, then
* Non-uni-bipartite and has only one node type (in this case identical to ``g.dstdata['feat']`` returns a dictionary where the key is
:any:`DGLHeteroGraph.ndata`) the destination node type and the value is the node feature
tensor. If the destination node type does not have feature
`'feat'`, it is not included in the dictionary.
Examples Examples
-------- --------
...@@ -802,7 +878,7 @@ class DGLHeteroGraph(object): ...@@ -802,7 +878,7 @@ class DGLHeteroGraph(object):
To set features of all source nodes in a graph with only one edge type: To set features of all source nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.dstdata['h'] = torch.zeros(3, 5) >>> g.dstdata['h'] = torch.zeros(3, 5)
This is equivalent to This is equivalent to
...@@ -812,13 +888,47 @@ class DGLHeteroGraph(object): ...@@ -812,13 +888,47 @@ class DGLHeteroGraph(object):
Also work on more complex uni-bipartite graph Also work on more complex uni-bipartite graph
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'), [(0, 1), (1, 2)], ... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('store', 'sells', 'game'), [(0, 1), (1, 0)], ... ('store', 'sells', 'game') : ([0, 1], [1, 0]),
... }) ... })
>>> print(g.is_unibipartite) >>> print(g.is_unibipartite)
True True
>>> g.dstdata['h'] = torch.zeros(3, 5) >>> g.dstdata['h'] = torch.zeros(3, 5)
To set features of all destination nodes in a uni-bipartite graph
with multiple destination node types::
>>> g = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... })
>>> print(g.is_unibipartite)
True
>>> g.dstdata['h'] = {'game' : torch.zeros(3, 5),
... 'book' : torch.zeros(2, 5)}
>>> g.dstdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... 'book': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of destination nodes in a uni-bipartite graph
with multiple destination node types:
>>> g = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1], [1, 2]),
... ('user', 'reads', 'book') : ([0, 1], [1, 0]),
... })
>>> g.dstdata['h'] = {'game' : torch.zeros(3, 5)}
>>> g.dstdata['h']
... {'game': tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no destination node type contains 'h'
>>> g.dstdata.pop('h')
>>> g.dstdata['h']
... {}
Notes Notes
----- -----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous. This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
...@@ -827,13 +937,14 @@ class DGLHeteroGraph(object): ...@@ -827,13 +937,14 @@ class DGLHeteroGraph(object):
-------- --------
nodes nodes
""" """
err_msg = ( if len(self.dsttypes) == 1:
'dstdata is only allowed when there is only one %s type.' % ntype = self.dsttypes[0]
('DST' if self.is_unibipartite else 'node')) ntid = self.get_ntype_id_from_dst(ntype)
assert len(self.dsttypes) == 1, err_msg return HeteroNodeDataView(self, ntype, ntid, ALL)
ntype = self.dsttypes[0] else:
ntid = self.get_ntype_id_from_dst(ntype) ntypes = self.dsttypes
return HeteroNodeDataView(self, ntype, ntid, ALL) ntids = [self.get_ntype_id_from_dst(ntype) for ntype in ntypes]
return HeteroNodeDataView(self, ntypes, ntids, ALL)
@property @property
def edges(self): def edges(self):
...@@ -846,7 +957,7 @@ class DGLHeteroGraph(object): ...@@ -846,7 +957,7 @@ class DGLHeteroGraph(object):
To set features of all "play" relationships: To set features of all "play" relationships:
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edges['plays'].data['h'] = torch.zeros(3, 4) >>> g.edges['plays'].data['h'] = torch.zeros(3, 4)
See Also See Also
...@@ -859,7 +970,16 @@ class DGLHeteroGraph(object): ...@@ -859,7 +970,16 @@ class DGLHeteroGraph(object):
def edata(self): def edata(self):
"""Return the data view of all the edges. """Return the data view of all the edges.
**Only works if the graph has one edge type.** If the graph has only one edge type, ``g.edata['feat']`` gives the
edge feature data under name ``'feat'``.
If the graph has multiple edge types, then ``g.edata['feat']``
returns a dictionary where the key is the edge type and the value
is the edge feature tensor. If the edge type does not have feature
``'feat'``, it is not included in the dictionary.
Note: When the graph has multiple edge type, The key used in
``g.edata['feat']`` should be the canonical_etypes, i.e.
(h_ntype, r_type, t_ntype).
Examples Examples
-------- --------
...@@ -868,14 +988,47 @@ class DGLHeteroGraph(object): ...@@ -868,14 +988,47 @@ class DGLHeteroGraph(object):
To set features of all edges in a heterogeneous graph To set features of all edges in a heterogeneous graph
with only one edge type: with only one edge type:
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.edata['h'] = torch.zeros(2, 5) >>> g.edata['h'] = torch.zeros(2, 5)
To set features of all edges in a heterogeneous graph
with multiple edge types:
>>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> g = dgl.hetero_from_relations([g0, g1])
>>> g.edata['h'] = {('user', 'watches', 'movie') : torch.zeros(3, 5),
('user', 'watches', 'TV') : torch.zeros(2, 5)}
>>> g.edata['h']
... {('user', 'watches', 'movie'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]]),
... ('user', 'watches', 'TV'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
To set features of part of edges in a heterogeneous graph
with multiple edge types:
>>> g0 = dgl.bipartite(([0, 1, 1], [1, 0, 1]), 'user', 'watches', 'movie')
>>> g1 = dgl.bipartite(([0, 1], [0, 1]), 'user', 'watches', 'TV')
>>> g = dgl.hetero_from_relations([g0, g1])
>>> g.edata['h'] = {('user', 'watches', 'movie') : torch.zeros(3, 5)}
>>> g.edata['h']
... {('user', 'watches', 'movie'): tensor([[0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.],
... [0., 0., 0., 0., 0.]])}
>>> # clean the feature 'h' and no edge type contains 'h'
>>> g.edata.pop('h')
>>> g.edata['h']
... {}
See Also See Also
-------- --------
edges edges
""" """
return HeteroEdgeDataView(self, None, ALL) if len(self.canonical_etypes) == 1:
return HeteroEdgeDataView(self, None, ALL)
else:
return HeteroEdgeDataView(self, self.canonical_etypes, ALL)
def _find_etypes(self, key): def _find_etypes(self, key):
etypes = [ etypes = [
...@@ -999,7 +1152,7 @@ class DGLHeteroGraph(object): ...@@ -999,7 +1152,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.number_of_nodes('user') >>> g.number_of_nodes('user')
3 3
>>> g.number_of_nodes() >>> g.number_of_nodes()
...@@ -1025,7 +1178,7 @@ class DGLHeteroGraph(object): ...@@ -1025,7 +1178,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.number_of_src_nodes('user') >>> g.number_of_src_nodes('user')
2 2
>>> g.number_of_src_nodes() >>> g.number_of_src_nodes()
...@@ -1053,7 +1206,7 @@ class DGLHeteroGraph(object): ...@@ -1053,7 +1206,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1], [1, 2]), 'user', 'plays', 'game')
>>> g.number_of_dst_nodes('game') >>> g.number_of_dst_nodes('game')
3 3
>>> g.number_of_dst_nodes() >>> g.number_of_dst_nodes()
...@@ -1080,8 +1233,9 @@ class DGLHeteroGraph(object): ...@@ -1080,8 +1233,9 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.number_of_edges(('user', 'follows', 'user')) >>> g.number_of_edges(('user', 'follows', 'user'))
2
>>> g.number_of_edges('follows') >>> g.number_of_edges('follows')
2 2
>>> g.number_of_edges() >>> g.number_of_edges()
...@@ -1290,8 +1444,8 @@ class DGLHeteroGraph(object): ...@@ -1290,8 +1444,8 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game') >>> devs_g = dgl.bipartite(([0, 1], [0, 1]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([plays_g, devs_g]) >>> g = dgl.hetero_from_relations([plays_g, devs_g])
>>> g.predecessors(0, 'plays') >>> g.predecessors(0, 'plays')
tensor([0, 1]) tensor([0, 1])
...@@ -1328,8 +1482,8 @@ class DGLHeteroGraph(object): ...@@ -1328,8 +1482,8 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> g.successors(0, 'plays') >>> g.successors(0, 'plays')
tensor([0]) tensor([0])
...@@ -1381,8 +1535,8 @@ class DGLHeteroGraph(object): ...@@ -1381,8 +1535,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for edge id. Query for edge id.
...@@ -1456,8 +1610,8 @@ class DGLHeteroGraph(object): ...@@ -1456,8 +1610,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for edge ids. Query for edge ids.
...@@ -1515,7 +1669,7 @@ class DGLHeteroGraph(object): ...@@ -1515,7 +1669,7 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.find_edges([0, 2], ('user', 'plays', 'game')) >>> g.find_edges([0, 2], ('user', 'plays', 'game'))
(tensor([0, 1]), tensor([0, 2])) (tensor([0, 1]), tensor([0, 2]))
>>> g.find_edges([0, 2]) >>> g.find_edges([0, 2])
...@@ -1572,7 +1726,7 @@ class DGLHeteroGraph(object): ...@@ -1572,7 +1726,7 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 1, 2]), 'user', 'plays', 'game')
>>> g.in_edges([0, 2], form='eid') >>> g.in_edges([0, 2], form='eid')
tensor([0, 2]) tensor([0, 2])
>>> g.in_edges([0, 2], form='all') >>> g.in_edges([0, 2], form='all')
...@@ -1624,7 +1778,7 @@ class DGLHeteroGraph(object): ...@@ -1624,7 +1778,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 0), (1, 1), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 1, 2]), 'user', 'plays', 'game')
>>> g.out_edges([0, 1], form='eid') >>> g.out_edges([0, 1], form='eid')
tensor([0, 1, 2]) tensor([0, 1, 2])
>>> g.out_edges([0, 1], form='all') >>> g.out_edges([0, 1], form='all')
...@@ -1680,7 +1834,7 @@ class DGLHeteroGraph(object): ...@@ -1680,7 +1834,7 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g = dgl.bipartite([(1, 1), (0, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([1, 0, 1], [1, 0, 2]), 'user', 'plays', 'game')
>>> g.all_edges(form='eid', order='srcdst') >>> g.all_edges(form='eid', order='srcdst')
tensor([1, 0, 2]) tensor([1, 0, 2])
>>> g.all_edges(form='all', order='srcdst') >>> g.all_edges(form='all', order='srcdst')
...@@ -1719,8 +1873,8 @@ class DGLHeteroGraph(object): ...@@ -1719,8 +1873,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree. Query for node degree.
...@@ -1760,8 +1914,8 @@ class DGLHeteroGraph(object): ...@@ -1760,8 +1914,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree. Query for node degree.
...@@ -1805,8 +1959,8 @@ class DGLHeteroGraph(object): ...@@ -1805,8 +1959,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree. Query for node degree.
...@@ -1846,8 +2000,8 @@ class DGLHeteroGraph(object): ...@@ -1846,8 +2000,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
Query for node degree. Query for node degree.
...@@ -1925,8 +2079,8 @@ class DGLHeteroGraph(object): ...@@ -1925,8 +2079,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features >>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
...@@ -2014,8 +2168,8 @@ class DGLHeteroGraph(object): ...@@ -2014,8 +2168,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set edge features >>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
...@@ -2091,8 +2245,8 @@ class DGLHeteroGraph(object): ...@@ -2091,8 +2245,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set node features >>> # Set node features
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
...@@ -2172,8 +2326,8 @@ class DGLHeteroGraph(object): ...@@ -2172,8 +2326,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> follows_g = dgl.graph([(0, 1), (1, 2), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1, 1], [1, 2, 2]), 'user', 'follows')
>>> g = dgl.hetero_from_relations([plays_g, follows_g]) >>> g = dgl.hetero_from_relations([plays_g, follows_g])
>>> # Set edge features >>> # Set edge features
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [2.]])
...@@ -2256,8 +2410,8 @@ class DGLHeteroGraph(object): ...@@ -2256,8 +2410,8 @@ class DGLHeteroGraph(object):
Instantiate a heterogeneous graph. Instantiate a heterogeneous graph.
>>> follows_g = dgl.graph([(0, 0), (1, 1)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [0, 1]), 'user', 'follows')
>>> devs_g = dgl.bipartite([(0, 0), (1, 2)], 'developer', 'develops', 'game') >>> devs_g = dgl.bipartite(([0, 1], [0, 2]), 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, devs_g]) >>> g = dgl.hetero_from_relations([follows_g, devs_g])
Get a backend dependent sparse tensor. Here we use PyTorch for example. Get a backend dependent sparse tensor. Here we use PyTorch for example.
...@@ -2337,7 +2491,7 @@ class DGLHeteroGraph(object): ...@@ -2337,7 +2491,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [0, 2]), 'user', 'follows')
>>> g.incidence_matrix('in') >>> g.incidence_matrix('in')
tensor(indices=tensor([[0, 2], tensor(indices=tensor([[0, 2],
[0, 1]]), [0, 1]]),
...@@ -2386,7 +2540,7 @@ class DGLHeteroGraph(object): ...@@ -2386,7 +2540,7 @@ class DGLHeteroGraph(object):
-------- --------
The following uses PyTorch backend. The following uses PyTorch backend.
>>> g = dgl.graph([(0, 0), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [0, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.randn(3, 4) >>> g.nodes['user'].data['h'] = torch.randn(3, 4)
>>> g.node_attr_schemes('user') >>> g.node_attr_schemes('user')
{'h': Scheme(shape=(4,), dtype=torch.float32)} {'h': Scheme(shape=(4,), dtype=torch.float32)}
...@@ -2418,7 +2572,7 @@ class DGLHeteroGraph(object): ...@@ -2418,7 +2572,7 @@ class DGLHeteroGraph(object):
-------- --------
The following uses PyTorch backend. The following uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4) >>> g.edges['user', 'plays', 'game'].data['h'] = torch.randn(4, 4)
>>> g.edge_attr_schemes(('user', 'plays', 'game')) >>> g.edge_attr_schemes(('user', 'plays', 'game'))
{'h': Scheme(shape=(4,), dtype=torch.float32)} {'h': Scheme(shape=(4,), dtype=torch.float32)}
...@@ -2716,7 +2870,7 @@ class DGLHeteroGraph(object): ...@@ -2716,7 +2870,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.ones(3, 5) >>> g.nodes['user'].data['h'] = torch.ones(3, 5)
>>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user') >>> g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, ntype='user')
>>> g.nodes['user'].data['h'] >>> g.nodes['user'].data['h']
...@@ -2762,7 +2916,7 @@ class DGLHeteroGraph(object): ...@@ -2762,7 +2916,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5) >>> g.edges[('user', 'plays', 'game')].data['h'] = torch.ones(4, 5)
>>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2}) >>> g.apply_edges(lambda edges: {'h': edges.data['h'] * 2})
>>> g.edges[('user', 'plays', 'game')].data['h'] >>> g.edges[('user', 'plays', 'game')].data['h']
...@@ -2824,7 +2978,7 @@ class DGLHeteroGraph(object): ...@@ -2824,7 +2978,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.graph([(0, 1), (0, 2), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 0, 1], [1, 2, 2]), 'user', 'follows')
>>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1)) >>> g.edata['feat'] = torch.randn((g.number_of_edges(), 1))
>>> def softmax_feat(edges): >>> def softmax_feat(edges):
>>> return {'norm_feat': th.softmax(edges.data['feat'], dim=1)} >>> return {'norm_feat': th.softmax(edges.data['feat'], dim=1)}
...@@ -2908,7 +3062,7 @@ class DGLHeteroGraph(object): ...@@ -2908,7 +3062,7 @@ class DGLHeteroGraph(object):
>>> import dgl.function as fn >>> import dgl.function as fn
>>> import torch >>> import torch
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Different ways for sending messages. Different ways for sending messages.
...@@ -3082,8 +3236,8 @@ class DGLHeteroGraph(object): ...@@ -3082,8 +3236,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1)], 'user', 'follows') >>> g1 = dgl.graph(([0], [1]), 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user') >>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2]) >>> g = dgl.hetero_from_relations([g1, g2])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
...@@ -3187,8 +3341,8 @@ class DGLHeteroGraph(object): ...@@ -3187,8 +3341,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (1, 0), (1, 1), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 1, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
Trigger "send" and "receive" separately. Trigger "send" and "receive" separately.
...@@ -3280,8 +3434,8 @@ class DGLHeteroGraph(object): ...@@ -3280,8 +3434,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1)], 'user', 'follows') >>> g1 = dgl.graph(([0], [1]), 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user') >>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2]) >>> g = dgl.hetero_from_relations([g1, g2])
Trigger send and recv separately. Trigger send and recv separately.
...@@ -3411,8 +3565,8 @@ class DGLHeteroGraph(object): ...@@ -3411,8 +3565,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows') >>> follows_g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (2, 1)], 'user', 'plays', 'game') >>> plays_g = dgl.bipartite(([0, 2], [0, 1]), 'user', 'plays', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g]) >>> g = dgl.hetero_from_relations([follows_g, plays_g])
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
...@@ -3482,8 +3636,8 @@ class DGLHeteroGraph(object): ...@@ -3482,8 +3636,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g1 = dgl.graph([(1, 1), (1, 0)], 'user', 'follows') >>> g1 = dgl.graph(([1, 1], [1, 0]), 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user') >>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2]) >>> g = dgl.hetero_from_relations([g1, g2])
Pull. Pull.
...@@ -3580,7 +3734,7 @@ class DGLHeteroGraph(object): ...@@ -3580,7 +3734,7 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g = dgl.graph([(0, 1), (0, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 0], [1, 2]), 'user', 'follows')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
Push. Push.
...@@ -3648,7 +3802,7 @@ class DGLHeteroGraph(object): ...@@ -3648,7 +3802,7 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g = dgl.graph([(0, 1), (1, 2), (2, 2)], 'user', 'follows') >>> g = dgl.graph(([0, 1, 2], [1, 2, 2]), 'user', 'follows')
Update all. Update all.
...@@ -3712,8 +3866,8 @@ class DGLHeteroGraph(object): ...@@ -3712,8 +3866,8 @@ class DGLHeteroGraph(object):
Instantiate a heterograph. Instantiate a heterograph.
>>> g1 = dgl.graph([(0, 1), (1, 1)], 'user', 'follows') >>> g1 = dgl.graph(([0, 1], [1, 1]), 'user', 'follows')
>>> g2 = dgl.bipartite([(0, 1)], 'game', 'attracts', 'user') >>> g2 = dgl.bipartite(([0], [1]), 'game', 'attracts', 'user')
>>> g = dgl.hetero_from_relations([g1, g2]) >>> g = dgl.hetero_from_relations([g1, g2])
>>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[1.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[1.]])
...@@ -4013,7 +4167,7 @@ class DGLHeteroGraph(object): ...@@ -4013,7 +4167,7 @@ class DGLHeteroGraph(object):
>>> import torch >>> import torch
>>> import dgl >>> import dgl
>>> import dgl.function as fn >>> import dgl.function as fn
>>> g = dgl.graph([(0, 0), (0, 1), (1, 2), (2, 3)], 'user', 'follows') >>> g = dgl.graph(([0, 0, 1, 2], [0, 1, 2, 3]), 'user', 'follows')
>>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]]) >>> g.edges['follows'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows') >>> g.filter_edges(lambda edges: (edges.data['h'] == 1.).squeeze(1), etype='follows')
tensor([1, 2]) tensor([1, 2])
...@@ -4055,7 +4209,7 @@ class DGLHeteroGraph(object): ...@@ -4055,7 +4209,7 @@ class DGLHeteroGraph(object):
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> print(g.device) >>> print(g.device)
device(type='cpu') device(type='cpu')
>>> g = g.to('cuda:0') >>> g = g.to('cuda:0')
...@@ -4088,7 +4242,7 @@ class DGLHeteroGraph(object): ...@@ -4088,7 +4242,7 @@ class DGLHeteroGraph(object):
The following example uses PyTorch backend. The following example uses PyTorch backend.
>>> import torch >>> import torch
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1, 2], [0, 0, 2, 1]), 'user', 'plays', 'game')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]]) >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
>>> g1 = g.to(torch.device('cuda:0')) >>> g1 = g.to(torch.device('cuda:0'))
...@@ -4146,7 +4300,7 @@ class DGLHeteroGraph(object): ...@@ -4146,7 +4300,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h'] >>> return g.edata['h']
>>> >>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
>>> newh = foo(g) # get tensor of all ones >>> newh = foo(g) # get tensor of all ones
>>> print(g.edata['h']) # still get tensor of all zeros >>> print(g.edata['h']) # still get tensor of all zeros
...@@ -4160,7 +4314,7 @@ class DGLHeteroGraph(object): ...@@ -4160,7 +4314,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h'] >>> return g.edata['h']
>>> >>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> h = foo(g) >>> h = foo(g)
>>> print('h' in g.edata) >>> print('h' in g.edata)
False False
...@@ -4198,7 +4352,7 @@ class DGLHeteroGraph(object): ...@@ -4198,7 +4352,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h'] >>> return g.edata['h']
>>> >>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.zeros((g.number_of_edges(), 3))
>>> newh = foo(g) # get tensor of all ones >>> newh = foo(g) # get tensor of all ones
>>> print(g.edata['h']) # still get tensor of all zeros >>> print(g.edata['h']) # still get tensor of all zeros
...@@ -4212,7 +4366,7 @@ class DGLHeteroGraph(object): ...@@ -4212,7 +4366,7 @@ class DGLHeteroGraph(object):
>>> g.edata['h'] = torch.ones((g.number_of_edges(), 3)) >>> g.edata['h'] = torch.ones((g.number_of_edges(), 3))
>>> return g.edata['h'] >>> return g.edata['h']
>>> >>>
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game') >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game')
>>> h = foo(g) >>> h = foo(g)
>>> print('h' in g.edata) >>> print('h' in g.edata)
False False
...@@ -4251,15 +4405,15 @@ class DGLHeteroGraph(object): ...@@ -4251,15 +4405,15 @@ class DGLHeteroGraph(object):
-------- --------
For graph with only one edge type. For graph with only one edge type.
>>> g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', restrict_format='csr') >>> g = dgl.graph(([0, 1], [1, 2]), 'user', 'follows', restrict_format='csr')
>>> g.format_in_use() >>> g.format_in_use()
['csr'] ['csr']
For a graph with multiple types. For a graph with multiple types.
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)], ... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any') ... }, restrict_format='any')
>>> g.format_in_use('develops') >>> g.format_in_use('develops')
['coo'] ['coo']
...@@ -4306,8 +4460,8 @@ class DGLHeteroGraph(object): ...@@ -4306,8 +4460,8 @@ class DGLHeteroGraph(object):
For a graph with multiple types. For a graph with multiple types.
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)], ... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any') ... }, restrict_format='any')
>>> g.restrict_format('develops') >>> g.restrict_format('develops')
'any' 'any'
...@@ -4354,8 +4508,8 @@ class DGLHeteroGraph(object): ...@@ -4354,8 +4508,8 @@ class DGLHeteroGraph(object):
For a graph with multiple types. For a graph with multiple types.
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)], ... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='any') ... }, restrict_format='any')
>>> g.format_in_use('develops') >>> g.format_in_use('develops')
['coo'] ['coo']
...@@ -4416,8 +4570,8 @@ class DGLHeteroGraph(object): ...@@ -4416,8 +4570,8 @@ class DGLHeteroGraph(object):
For a graph with multiple edge types: For a graph with multiple edge types:
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (1, 1), (2, 1)], ... ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
... ('developer', 'develops', 'game'): [(0, 0), (1, 1)], ... ('developer', 'develops', 'game'): ([0, 1], [0, 1]),
... }, restrict_format='coo') ... }, restrict_format='coo')
>>> g.restrict_format('develops') >>> g.restrict_format('develops')
'coo' 'coo'
...@@ -4447,7 +4601,7 @@ class DGLHeteroGraph(object): ...@@ -4447,7 +4601,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game', >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game',
>>> index_dtype='int32') >>> index_dtype='int32')
>>> g_long = g.long() # Convert g to int64 indexed, not changing the original `g` >>> g_long = g.long() # Convert g to int64 indexed, not changing the original `g`
...@@ -4472,7 +4626,7 @@ class DGLHeteroGraph(object): ...@@ -4472,7 +4626,7 @@ class DGLHeteroGraph(object):
Examples Examples
-------- --------
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2)], 'user', 'plays', 'game', >>> g = dgl.bipartite(([0, 1, 1], [0, 0, 2]), 'user', 'plays', 'game',
>>> index_dtype='int64') >>> index_dtype='int64')
>>> g_int = g.int() # Convert g to int32 indexed, not changing the original `g` >>> g_int = g.int() # Convert g to int32 indexed, not changing the original `g`
......
...@@ -291,24 +291,65 @@ class HeteroNodeDataView(MutableMapping): ...@@ -291,24 +291,65 @@ class HeteroNodeDataView(MutableMapping):
self._nodes = nodes self._nodes = nodes
def __getitem__(self, key): def __getitem__(self, key):
return self._graph._get_n_repr(self._ntid, self._nodes)[key] if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(key, None)
if value is not None:
ret[ntype] = value
return ret
else:
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph._set_n_repr(self._ntid, self._nodes, {key : val}) if isinstance(self._ntype, list):
assert isinstance(val, dict), \
'Current HeteroNodeDataView has multiple node types, ' \
'please passing the node type and the corresponding data through a dict.'
for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroNodeDataView has only one node type. ' \
'please pass a tensor directly'
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph._pop_n_repr(self._ntid, key) if isinstance(self._ntype, list):
for ntid in self._ntid:
if self._graph._get_n_repr(ntid, ALL).get(key, None) is None:
continue
self._graph._pop_n_repr(ntid, key)
else:
self._graph._pop_n_repr(self._ntid, key)
def __len__(self): def __len__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not support len().'
return len(self._graph._node_frames[self._ntid]) return len(self._graph._node_frames[self._ntid])
def __iter__(self): def __iter__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not be iterated.'
return iter(self._graph._node_frames[self._ntid]) return iter(self._graph._node_frames[self._ntid])
def __repr__(self): def __repr__(self):
data = self._graph._get_n_repr(self._ntid, self._nodes) if isinstance(self._ntype, list):
return repr({key : data[key] ret = {}
for key in self._graph._node_frames[self._ntid]}) for (i, ntype) in enumerate(self._ntype):
data = self._graph._get_n_repr(self._ntid[i], self._nodes)
value = {key : data[key]
for key in self._graph._node_frames[self._ntid[i]]}
ret[ntype] = value
return repr(ret)
else:
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
...@@ -345,31 +386,74 @@ class HeteroEdgeView(object): ...@@ -345,31 +386,74 @@ class HeteroEdgeView(object):
return self._graph.all_edges(*args, **kwargs) return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeDataView(MutableMapping): class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.ndata[etype] is called.""" """The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges'] __slots__ = ['_graph', '_etype', '_etid', '_edges']
def __init__(self, graph, etype, edges): def __init__(self, graph, etype, edges):
self._graph = graph self._graph = graph
self._etype = etype self._etype = etype
self._etid = self._graph.get_etype_id(etype) self._etid = [self._graph.get_etype_id(t) for t in etype] \
if isinstance(etype, list) \
else self._graph.get_etype_id(etype)
self._edges = edges self._edges = edges
def __getitem__(self, key): def __getitem__(self, key):
return self._graph._get_e_repr(self._etid, self._edges)[key] if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(key, None)
if value is not None:
ret[etype] = value
return ret
else:
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph._set_e_repr(self._etid, self._edges, {key : val}) if isinstance(self._etype, list):
assert isinstance(val, dict), \
'Current HeteroEdgeDataView has multiple edge types, ' \
'please pass the edge type and the corresponding data through a dict.'
for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroEdgeDataView has only one edge type. ' \
'please pass a tensor directly'
self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph._pop_e_repr(self._etid, key) if isinstance(self._etype, list):
for etid in self._etid:
if self._graph._get_e_repr(etid, ALL).get(key, None) is None:
continue
self._graph._pop_e_repr(etid, key)
else:
self._graph._pop_e_repr(self._etid, key)
def __len__(self): def __len__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not support len().'
return len(self._graph._edge_frames[self._etid]) return len(self._graph._edge_frames[self._etid])
def __iter__(self): def __iter__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not be iterated.'
return iter(self._graph._edge_frames[self._etid]) return iter(self._graph._edge_frames[self._etid])
def __repr__(self): def __repr__(self):
data = self._graph._get_e_repr(self._etid, self._edges) if isinstance(self._etype, list):
return repr({key : data[key] ret = {}
for key in self._graph._edge_frames[self._etid]}) for (i, etype) in enumerate(self._etype):
data = self._graph._get_e_repr(self._etid[i], self._edges)
value = {key : data[key]
for key in self._graph._edge_frames[self._etid[i]]}
ret[etype] = value
return repr(ret)
else:
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._etid]})
...@@ -494,6 +494,31 @@ def test_inc(index_dtype): ...@@ -494,6 +494,31 @@ def test_inc(index_dtype):
@parametrize_dtype @parametrize_dtype
def test_view(index_dtype): def test_view(index_dtype):
# test single node type
g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
f1 = F.randn((3, 6))
g.ndata['h'] = f1
f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2)
fail = False
try:
g.ndata['h'] = {'user' : f1}
except Exception:
fail = True
assert fail
# test single edge type
f3 = F.randn((2, 4))
g.edata['h'] = f3
f4 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
fail = False
try:
g.edata['h'] = {'follows' : f3}
except Exception:
fail = True
assert fail
# test data view # test data view
g = create_test_heterograph(index_dtype) g = create_test_heterograph(index_dtype)
...@@ -502,6 +527,31 @@ def test_view(index_dtype): ...@@ -502,6 +527,31 @@ def test_view(index_dtype):
f2 = g.nodes['user'].data['h'] f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2) assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.nodes('user')), F.arange(0, 3)) assert F.array_equal(F.tensor(g.nodes('user')), F.arange(0, 3))
g.nodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.ndata['h'] = f1
except Exception:
fail = True
assert fail
g.ndata['h'] = {'user' : f1,
'game' : f2}
f3 = g.nodes['user'].data['h']
f4 = g.nodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.ndata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.ndata)
g.ndata.pop('h')
# test repr
print(g.ndata)
f3 = F.randn((2, 4)) f3 = F.randn((2, 4))
g.edges['user', 'follows', 'user'].data['h'] = f3 g.edges['user', 'follows', 'user'].data['h'] = f3
...@@ -510,6 +560,87 @@ def test_view(index_dtype): ...@@ -510,6 +560,87 @@ def test_view(index_dtype):
assert F.array_equal(f3, f4) assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5) assert F.array_equal(f3, f5)
assert F.array_equal(F.tensor(g.edges(etype='follows', form='eid')), F.arange(0, 2)) assert F.array_equal(F.tensor(g.edges(etype='follows', form='eid')), F.arange(0, 2))
g.edges['follows'].data.pop('h')
f3 = F.randn((2, 4))
fail = False
try:
g.edata['h'] = f3
except Exception:
fail = True
assert fail
g.edata['h'] = {('user', 'follows', 'user') : f3}
f4 = g.edges['user', 'follows', 'user'].data['h']
f5 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5)
data = g.edata['h']
assert F.array_equal(f3, data[('user', 'follows', 'user')])
# test repr
print(g.edata)
g.edata.pop('h')
# test repr
print(g.edata)
# test srcdata
f1 = F.randn((3, 6))
g.srcnodes['user'].data['h'] = f1 # ok
f2 = g.srcnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.srcnodes('user')), F.arange(0, 3))
g.srcnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.srcdata['h'] = f1
except Exception:
fail = True
assert fail
g.srcdata['h'] = {'user' : f1,
'developer' : f2}
f3 = g.srcnodes['user'].data['h']
f4 = g.srcnodes['developer'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.srcdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['developer'])
# test repr
print(g.srcdata)
g.srcdata.pop('h')
# test dstdata
f1 = F.randn((3, 6))
g.dstnodes['user'].data['h'] = f1 # ok
f2 = g.dstnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.dstnodes('user')), F.arange(0, 3))
g.dstnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.dstdata['h'] = f1
except Exception:
fail = True
assert fail
g.dstdata['h'] = {'user' : f1,
'game' : f2}
f3 = g.dstnodes['user'].data['h']
f4 = g.dstnodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.dstdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.dstdata)
g.dstdata.pop('h')
@parametrize_dtype @parametrize_dtype
def test_view1(index_dtype): def test_view1(index_dtype):
...@@ -639,21 +770,14 @@ def test_view1(index_dtype): ...@@ -639,21 +770,14 @@ def test_view1(index_dtype):
assert F.array_equal(f3, f4) assert F.array_equal(f3, f4)
assert F.array_equal(F.tensor(g.edges(form='eid')), F.arange(0, 2)) assert F.array_equal(F.tensor(g.edges(form='eid')), F.arange(0, 2))
# test fail case # multiple types
# fail due to multiple types ndata = HG.ndata['h']
fail = False assert isinstance(ndata, dict)
try: assert F.array_equal(ndata['user'], f2)
HG.ndata['h']
except dgl.DGLError: edata = HG.edata['h']
fail = True assert isinstance(edata, dict)
assert fail assert F.array_equal(edata[('user', 'follows', 'user')], f4)
fail = False
try:
HG.edata['h']
except dgl.DGLError:
fail = True
assert fail
@parametrize_dtype @parametrize_dtype
def test_flatten(index_dtype): def test_flatten(index_dtype):
...@@ -1740,7 +1864,7 @@ if __name__ == '__main__': ...@@ -1740,7 +1864,7 @@ if __name__ == '__main__':
# test_hypersparse() # test_hypersparse()
# test_adj("int32") # test_adj("int32")
# test_inc() # test_inc()
# test_view() # test_view("int32")
# test_view1("int32") # test_view1("int32")
# test_flatten() # test_flatten()
# test_convert_bound() # test_convert_bound()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment