"...text-generation-inference.git" did not exist on "c2fd35d875155d858a60542edabb9df59587e1f8"
Unverified Commit ef93518d authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Refactor] Auto fix view.py. (#4461)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent ab0af814
......@@ -11,6 +11,7 @@ from .frame import LazyFeature
NodeSpace = namedtuple('NodeSpace', ['data'])
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph."""
__slots__ = ['_graph', '_typeid_getter']
......@@ -36,7 +37,9 @@ class HeteroNodeView(object):
nodes = key
ntype = None
ntid = self._typeid_getter(ntype)
return NodeSpace(data=HeteroNodeDataView(self._graph, ntype, ntid, nodes))
return NodeSpace(
data=HeteroNodeDataView(
self._graph, ntype, ntid, nodes))
def __call__(self, ntype=None):
"""Return the nodes."""
......@@ -45,6 +48,7 @@ class HeteroNodeView(object):
dtype=self._graph.idtype, ctx=self._graph.device)
return ret
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
......@@ -59,7 +63,9 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(key, None)
value = self._graph._get_n_repr(
self._ntid[i], self._nodes).get(
key, None)
if value is not None:
ret[ntype] = value
return ret
......@@ -76,12 +82,12 @@ class HeteroNodeDataView(MutableMapping):
for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key : data})
self._graph._set_n_repr(ntid, self._nodes, {key: data})
else:
assert isinstance(val, dict) is False, \
'The HeteroNodeDataView has only one node type. ' \
'please pass a tensor directly'
self._graph._set_n_repr(self._ntid, self._nodes, {key : val})
self._graph._set_n_repr(self._ntid, self._nodes, {key: val})
def __delitem__(self, key):
if isinstance(self._ntype, list):
......@@ -102,7 +108,8 @@ class HeteroNodeDataView(MutableMapping):
else:
ret = self._graph._get_n_repr(self._ntid, self._nodes)
if as_dict:
ret = {key: ret[key] for key in self._graph._node_frames[self._ntid]}
ret = {key: ret[key]
for key in self._graph._node_frames[self._ntid]}
return ret
def __len__(self):
......@@ -120,6 +127,7 @@ class HeteroNodeDataView(MutableMapping):
def __repr__(self):
return repr(self._transpose(as_dict=True))
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
__slots__ = ['_graph']
......@@ -157,6 +165,7 @@ class HeteroEdgeView(object):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges']
......@@ -165,15 +174,17 @@ class HeteroEdgeDataView(MutableMapping):
self._graph = graph
self._etype = etype
self._etid = [self._graph.get_etype_id(t) for t in etype] \
if isinstance(etype, list) \
else self._graph.get_etype_id(etype)
if isinstance(etype, list) \
else self._graph.get_etype_id(etype)
self._edges = edges
def __getitem__(self, key):
if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(key, None)
value = self._graph._get_e_repr(
self._etid[i], self._edges).get(
key, None)
if value is not None:
ret[etype] = value
return ret
......@@ -190,12 +201,12 @@ class HeteroEdgeDataView(MutableMapping):
for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key : data})
self._graph._set_e_repr(etid, self._edges, {key: data})
else:
assert isinstance(val, dict) is False, \
'The HeteroEdgeDataView has only one edge type. ' \
'please pass a tensor directly'
self._graph._set_e_repr(self._etid, self._edges, {key : val})
self._graph._set_e_repr(self._etid, self._edges, {key: val})
def __delitem__(self, key):
if isinstance(self._etype, list):
......@@ -216,7 +227,8 @@ class HeteroEdgeDataView(MutableMapping):
else:
ret = self._graph._get_e_repr(self._etid, self._edges)
if as_dict:
ret = {key: ret[key] for key in self._graph._edge_frames[self._etid]}
ret = {key: ret[key]
for key in self._graph._edge_frames[self._etid]}
return ret
def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment