Unverified Commit b742c559 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature]heterograph.edata and heterograph.ndata support multiple ntype and etype (#1673)



* g.edata and g.ndata support multiple ntype and etype

* Add test case

* support g.srcdata and g.dstdata

* Fix test

* lint

* Fix

* Fix some comments

* lint
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 8d16d4bb
This diff is collapsed.
...@@ -291,24 +291,65 @@ class HeteroNodeDataView(MutableMapping): ...@@ -291,24 +291,65 @@ class HeteroNodeDataView(MutableMapping):
self._nodes = nodes self._nodes = nodes
def __getitem__(self, key): def __getitem__(self, key):
return self._graph._get_n_repr(self._ntid, self._nodes)[key] if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(key, None)
if value is not None:
ret[ntype] = value
return ret
else:
return self._graph._get_n_repr(self._ntid, self._nodes)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph._set_n_repr(self._ntid, self._nodes, {key : val}) if isinstance(self._ntype, list):
assert isinstance(val, dict), \
'Current HeteroNodeDataView has multiple node types, ' \
'please passing the node type and the corresponding data through a dict.'
for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroNodeDataView has only one node type. ' \
'please pass a tensor directly'
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph._pop_n_repr(self._ntid, key) if isinstance(self._ntype, list):
for ntid in self._ntid:
if self._graph._get_n_repr(ntid, ALL).get(key, None) is None:
continue
self._graph._pop_n_repr(ntid, key)
else:
self._graph._pop_n_repr(self._ntid, key)
def __len__(self): def __len__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not support len().'
return len(self._graph._node_frames[self._ntid]) return len(self._graph._node_frames[self._ntid])
def __iter__(self): def __iter__(self):
assert isinstance(self._ntype, list) is False, \
'Current HeteroNodeDataView has multiple node types, ' \
'can not be iterated.'
return iter(self._graph._node_frames[self._ntid]) return iter(self._graph._node_frames[self._ntid])
def __repr__(self): def __repr__(self):
data = self._graph._get_n_repr(self._ntid, self._nodes) if isinstance(self._ntype, list):
return repr({key : data[key] ret = {}
for key in self._graph._node_frames[self._ntid]}) for (i, ntype) in enumerate(self._ntype):
data = self._graph._get_n_repr(self._ntid[i], self._nodes)
value = {key : data[key]
for key in self._graph._node_frames[self._ntid[i]]}
ret[ntype] = value
return repr(ret)
else:
data = self._graph._get_n_repr(self._ntid, self._nodes)
return repr({key : data[key]
for key in self._graph._node_frames[self._ntid]})
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
...@@ -345,31 +386,74 @@ class HeteroEdgeView(object): ...@@ -345,31 +386,74 @@ class HeteroEdgeView(object):
return self._graph.all_edges(*args, **kwargs) return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeDataView(MutableMapping): class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.ndata[etype] is called.""" """The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges'] __slots__ = ['_graph', '_etype', '_etid', '_edges']
def __init__(self, graph, etype, edges): def __init__(self, graph, etype, edges):
self._graph = graph self._graph = graph
self._etype = etype self._etype = etype
self._etid = self._graph.get_etype_id(etype) self._etid = [self._graph.get_etype_id(t) for t in etype] \
if isinstance(etype, list) \
else self._graph.get_etype_id(etype)
self._edges = edges self._edges = edges
def __getitem__(self, key): def __getitem__(self, key):
return self._graph._get_e_repr(self._etid, self._edges)[key] if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(key, None)
if value is not None:
ret[etype] = value
return ret
else:
return self._graph._get_e_repr(self._etid, self._edges)[key]
def __setitem__(self, key, val): def __setitem__(self, key, val):
self._graph._set_e_repr(self._etid, self._edges, {key : val}) if isinstance(self._etype, list):
assert isinstance(val, dict), \
'Current HeteroEdgeDataView has multiple edge types, ' \
'please pass the edge type and the corresponding data through a dict.'
for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key : data})
else:
assert isinstance(val, dict) is False, \
'The HeteroEdgeDataView has only one edge type. ' \
'please pass a tensor directly'
self._graph._set_e_repr(self._etid, self._edges, {key : val})
def __delitem__(self, key): def __delitem__(self, key):
self._graph._pop_e_repr(self._etid, key) if isinstance(self._etype, list):
for etid in self._etid:
if self._graph._get_e_repr(etid, ALL).get(key, None) is None:
continue
self._graph._pop_e_repr(etid, key)
else:
self._graph._pop_e_repr(self._etid, key)
def __len__(self): def __len__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not support len().'
return len(self._graph._edge_frames[self._etid]) return len(self._graph._edge_frames[self._etid])
def __iter__(self): def __iter__(self):
assert isinstance(self._etype, list) is False, \
'Current HeteroEdgeDataView has multiple edge types, ' \
'can not be iterated.'
return iter(self._graph._edge_frames[self._etid]) return iter(self._graph._edge_frames[self._etid])
def __repr__(self): def __repr__(self):
data = self._graph._get_e_repr(self._etid, self._edges) if isinstance(self._etype, list):
return repr({key : data[key] ret = {}
for key in self._graph._edge_frames[self._etid]}) for (i, etype) in enumerate(self._etype):
data = self._graph._get_e_repr(self._etid[i], self._edges)
value = {key : data[key]
for key in self._graph._edge_frames[self._etid[i]]}
ret[etype] = value
return repr(ret)
else:
data = self._graph._get_e_repr(self._etid, self._edges)
return repr({key : data[key]
for key in self._graph._edge_frames[self._etid]})
...@@ -494,6 +494,31 @@ def test_inc(index_dtype): ...@@ -494,6 +494,31 @@ def test_inc(index_dtype):
@parametrize_dtype @parametrize_dtype
def test_view(index_dtype): def test_view(index_dtype):
# test single node type
g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
f1 = F.randn((3, 6))
g.ndata['h'] = f1
f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2)
fail = False
try:
g.ndata['h'] = {'user' : f1}
except Exception:
fail = True
assert fail
# test single edge type
f3 = F.randn((2, 4))
g.edata['h'] = f3
f4 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
fail = False
try:
g.edata['h'] = {'follows' : f3}
except Exception:
fail = True
assert fail
# test data view # test data view
g = create_test_heterograph(index_dtype) g = create_test_heterograph(index_dtype)
...@@ -502,6 +527,31 @@ def test_view(index_dtype): ...@@ -502,6 +527,31 @@ def test_view(index_dtype):
f2 = g.nodes['user'].data['h'] f2 = g.nodes['user'].data['h']
assert F.array_equal(f1, f2) assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.nodes('user')), F.arange(0, 3)) assert F.array_equal(F.tensor(g.nodes('user')), F.arange(0, 3))
g.nodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.ndata['h'] = f1
except Exception:
fail = True
assert fail
g.ndata['h'] = {'user' : f1,
'game' : f2}
f3 = g.nodes['user'].data['h']
f4 = g.nodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.ndata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.ndata)
g.ndata.pop('h')
# test repr
print(g.ndata)
f3 = F.randn((2, 4)) f3 = F.randn((2, 4))
g.edges['user', 'follows', 'user'].data['h'] = f3 g.edges['user', 'follows', 'user'].data['h'] = f3
...@@ -510,6 +560,87 @@ def test_view(index_dtype): ...@@ -510,6 +560,87 @@ def test_view(index_dtype):
assert F.array_equal(f3, f4) assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5) assert F.array_equal(f3, f5)
assert F.array_equal(F.tensor(g.edges(etype='follows', form='eid')), F.arange(0, 2)) assert F.array_equal(F.tensor(g.edges(etype='follows', form='eid')), F.arange(0, 2))
g.edges['follows'].data.pop('h')
f3 = F.randn((2, 4))
fail = False
try:
g.edata['h'] = f3
except Exception:
fail = True
assert fail
g.edata['h'] = {('user', 'follows', 'user') : f3}
f4 = g.edges['user', 'follows', 'user'].data['h']
f5 = g.edges['follows'].data['h']
assert F.array_equal(f3, f4)
assert F.array_equal(f3, f5)
data = g.edata['h']
assert F.array_equal(f3, data[('user', 'follows', 'user')])
# test repr
print(g.edata)
g.edata.pop('h')
# test repr
print(g.edata)
# test srcdata
f1 = F.randn((3, 6))
g.srcnodes['user'].data['h'] = f1 # ok
f2 = g.srcnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.srcnodes('user')), F.arange(0, 3))
g.srcnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.srcdata['h'] = f1
except Exception:
fail = True
assert fail
g.srcdata['h'] = {'user' : f1,
'developer' : f2}
f3 = g.srcnodes['user'].data['h']
f4 = g.srcnodes['developer'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.srcdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['developer'])
# test repr
print(g.srcdata)
g.srcdata.pop('h')
# test dstdata
f1 = F.randn((3, 6))
g.dstnodes['user'].data['h'] = f1 # ok
f2 = g.dstnodes['user'].data['h']
assert F.array_equal(f1, f2)
assert F.array_equal(F.tensor(g.dstnodes('user')), F.arange(0, 3))
g.dstnodes['user'].data.pop('h')
# multi type ndata
f1 = F.randn((3, 6))
f2 = F.randn((2, 6))
fail = False
try:
g.dstdata['h'] = f1
except Exception:
fail = True
assert fail
g.dstdata['h'] = {'user' : f1,
'game' : f2}
f3 = g.dstnodes['user'].data['h']
f4 = g.dstnodes['game'].data['h']
assert F.array_equal(f1, f3)
assert F.array_equal(f2, f4)
data = g.dstdata['h']
assert F.array_equal(f1, data['user'])
assert F.array_equal(f2, data['game'])
# test repr
print(g.dstdata)
g.dstdata.pop('h')
@parametrize_dtype @parametrize_dtype
def test_view1(index_dtype): def test_view1(index_dtype):
...@@ -639,21 +770,14 @@ def test_view1(index_dtype): ...@@ -639,21 +770,14 @@ def test_view1(index_dtype):
assert F.array_equal(f3, f4) assert F.array_equal(f3, f4)
assert F.array_equal(F.tensor(g.edges(form='eid')), F.arange(0, 2)) assert F.array_equal(F.tensor(g.edges(form='eid')), F.arange(0, 2))
# test fail case # multiple types
# fail due to multiple types ndata = HG.ndata['h']
fail = False assert isinstance(ndata, dict)
try: assert F.array_equal(ndata['user'], f2)
HG.ndata['h']
except dgl.DGLError: edata = HG.edata['h']
fail = True assert isinstance(edata, dict)
assert fail assert F.array_equal(edata[('user', 'follows', 'user')], f4)
fail = False
try:
HG.edata['h']
except dgl.DGLError:
fail = True
assert fail
@parametrize_dtype @parametrize_dtype
def test_flatten(index_dtype): def test_flatten(index_dtype):
...@@ -1740,7 +1864,7 @@ if __name__ == '__main__': ...@@ -1740,7 +1864,7 @@ if __name__ == '__main__':
# test_hypersparse() # test_hypersparse()
# test_adj("int32") # test_adj("int32")
# test_inc() # test_inc()
# test_view() # test_view("int32")
# test_view1("int32") # test_view1("int32")
# test_flatten() # test_flatten()
# test_convert_bound() # test_convert_bound()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment