Unverified Commit ef93518d authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Refactor] Auto fix view.py. (#4461)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent ab0af814
...@@ -11,6 +11,7 @@ from .frame import LazyFeature ...@@ -11,6 +11,7 @@ from .frame import LazyFeature
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data']) EdgeSpace = namedtuple('EdgeSpace', ['data'])
class HeteroNodeView(object): class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph.""" """A NodeView class to act as G.nodes for a DGLHeteroGraph."""
__slots__ = ['_graph', '_typeid_getter'] __slots__ = ['_graph', '_typeid_getter']
...@@ -36,7 +37,9 @@ class HeteroNodeView(object): ...@@ -36,7 +37,9 @@ class HeteroNodeView(object):
nodes = key nodes = key
ntype = None ntype = None
ntid = self._typeid_getter(ntype) ntid = self._typeid_getter(ntype)
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, ntid, nodes)) return NodeSpace(
data=HeteroNodeDataView(
self._graph, ntype, ntid, nodes))
def __call__(self, ntype=None): def __call__(self, ntype=None):
"""Return the nodes.""" """Return the nodes."""
...@@ -45,6 +48,7 @@ class HeteroNodeView(object): ...@@ -45,6 +48,7 @@ class HeteroNodeView(object):
dtype=self._graph.idtype, ctx=self._graph.device) dtype=self._graph.idtype, ctx=self._graph.device)
return ret return ret
class HeteroNodeDataView(MutableMapping): class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called.""" """The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes'] __slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
...@@ -59,7 +63,9 @@ class HeteroNodeDataView(MutableMapping): ...@@ -59,7 +63,9 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(self._ntype, list): if isinstance(self._ntype, list):
ret = {} ret = {}
for (i, ntype) in enumerate(self._ntype): for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(key, None) value = self._graph._get_n_repr(
self._ntid[i], self._nodes).get(
key, None)
if value is not None: if value is not None:
ret[ntype] = value ret[ntype] = value
return ret return ret
...@@ -76,12 +82,12 @@ class HeteroNodeDataView(MutableMapping): ...@@ -76,12 +82,12 @@ class HeteroNodeDataView(MutableMapping):
for (ntype, data) in val.items(): for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype) ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key : data}) self._graph._set_n_repr(ntid, self._nodes, {key: data})
else: else:
assert isinstance(val, dict) is False, \ assert isinstance(val, dict) is False, \
'The HeteroNodeDataView has only one node type. ' \ 'The HeteroNodeDataView has only one node type. ' \
'please pass a tensor directly' 'please pass a tensor directly'
self._graph._set_n_repr(self._ntid, self._nodes, {key : val}) self._graph._set_n_repr(self._ntid, self._nodes, {key: val})
def __delitem__(self, key): def __delitem__(self, key):
if isinstance(self._ntype, list): if isinstance(self._ntype, list):
...@@ -102,7 +108,8 @@ class HeteroNodeDataView(MutableMapping): ...@@ -102,7 +108,8 @@ class HeteroNodeDataView(MutableMapping):
else: else:
ret = self._graph._get_n_repr(self._ntid, self._nodes) ret = self._graph._get_n_repr(self._ntid, self._nodes)
if as_dict: if as_dict:
ret = {key: ret[key] for key in self._graph._node_frames[self._ntid]} ret = {key: ret[key]
for key in self._graph._node_frames[self._ntid]}
return ret return ret
def __len__(self): def __len__(self):
...@@ -120,6 +127,7 @@ class HeteroNodeDataView(MutableMapping): ...@@ -120,6 +127,7 @@ class HeteroNodeDataView(MutableMapping):
def __repr__(self): def __repr__(self):
return repr(self._transpose(as_dict=True)) return repr(self._transpose(as_dict=True))
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
__slots__ = ['_graph'] __slots__ = ['_graph']
...@@ -157,6 +165,7 @@ class HeteroEdgeView(object): ...@@ -157,6 +165,7 @@ class HeteroEdgeView(object):
"""Return all the edges.""" """Return all the edges."""
return self._graph.all_edges(*args, **kwargs) return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeDataView(MutableMapping): class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.edata[etype] is called.""" """The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges'] __slots__ = ['_graph', '_etype', '_etid', '_edges']
...@@ -173,7 +182,9 @@ class HeteroEdgeDataView(MutableMapping): ...@@ -173,7 +182,9 @@ class HeteroEdgeDataView(MutableMapping):
if isinstance(self._etype, list): if isinstance(self._etype, list):
ret = {} ret = {}
for (i, etype) in enumerate(self._etype): for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(key, None) value = self._graph._get_e_repr(
self._etid[i], self._edges).get(
key, None)
if value is not None: if value is not None:
ret[etype] = value ret[etype] = value
return ret return ret
...@@ -190,12 +201,12 @@ class HeteroEdgeDataView(MutableMapping): ...@@ -190,12 +201,12 @@ class HeteroEdgeDataView(MutableMapping):
for (etype, data) in val.items(): for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype) etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key : data}) self._graph._set_e_repr(etid, self._edges, {key: data})
else: else:
assert isinstance(val, dict) is False, \ assert isinstance(val, dict) is False, \
'The HeteroEdgeDataView has only one edge type. ' \ 'The HeteroEdgeDataView has only one edge type. ' \
'please pass a tensor directly' 'please pass a tensor directly'
self._graph._set_e_repr(self._etid, self._edges, {key : val}) self._graph._set_e_repr(self._etid, self._edges, {key: val})
def __delitem__(self, key): def __delitem__(self, key):
if isinstance(self._etype, list): if isinstance(self._etype, list):
...@@ -216,7 +227,8 @@ class HeteroEdgeDataView(MutableMapping): ...@@ -216,7 +227,8 @@ class HeteroEdgeDataView(MutableMapping):
else: else:
ret = self._graph._get_e_repr(self._etid, self._edges) ret = self._graph._get_e_repr(self._etid, self._edges)
if as_dict: if as_dict:
ret = {key: ret[key] for key in self._graph._edge_frames[self._etid]} ret = {key: ret[key]
for key in self._graph._edge_frames[self._etid]}
return ret return ret
def __len__(self): def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment