Commit b0e02e5b authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

Update the subgraph (#73)

* update subgraph.

* update subgraph API.

* keep node embedding.
parent 2883eda6
......@@ -141,7 +141,7 @@ class FrameRef(MutableMapping):
else:
self.update_rows(key, val)
def add_column(self, name, col):
def add_column(self, name, col, inplace=False):
shp = F.shape(col)
if self.is_span_whole_column():
if self.num_columns == 0:
......@@ -157,18 +157,25 @@ class FrameRef(MutableMapping):
fcol = F.zeros((self._frame.num_rows,) + shp[1:])
fcol = F.to_context(fcol, colctx)
idx = self.index().tousertensor(colctx)
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
if inplace:
self._frame[name] = fcol
self._frame[name][idx] = col
else:
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
def update_rows(self, query, other):
def update_rows(self, query, other, inplace=False):
rowids = self._getrowid(query)
for key, col in other.items():
if key not in self:
# add new column
tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col)
tmpref.add_column(key, col, inplace)
idx = rowids.tousertensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
if inplace:
self._frame[key][idx] = col
else:
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key):
if isinstance(key, str):
......
......@@ -486,7 +486,7 @@ class DGLGraph(object):
"""
return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL):
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or
......@@ -524,9 +524,9 @@ class DGLGraph(object):
self._node_frame[__REPR__] = hu
else:
if utils.is_dict_like(hu):
self._node_frame[u] = hu
self._node_frame.update_rows(u, hu, inplace=inplace)
else:
self._node_frame[u] = {__REPR__ : hu}
self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
def get_n_repr(self, u=ALL):
"""Get node(s) representation.
......
......@@ -15,28 +15,21 @@ class DGLSubGraph(DGLGraph):
nodes):
super(DGLSubGraph, self).__init__()
# relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes)
self._parent = parent
self._parent_nid = utils.toindex(nodes)
eids = []
# create subgraph
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
self._graph, self._parent_eid = parent._graph.node_subgraph(self._parent_nid)
self.reset_messages()
def copy_from(self, parent):
def copy_to_parent(self, inplace=False):
self._parent._node_frame.update_rows(self._parent_nid, self._node_frame, inplace=inplace)
self._parent._edge_frame.update_rows(self._parent_eid, self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph.
All old features will be removed.
Parameters
----------
parent : DGLGraph
The parent graph to copy from.
"""
if parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid]))
if parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid]))
if self._parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(self._parent._node_frame[self._parent_nid]))
if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(self._parent._edge_frame[self._parent_eid]))
......@@ -35,7 +35,7 @@ def test_basics():
assert len(sg.get_n_repr()) == 0
assert len(sg.get_e_repr()) == 0
# the data is copied after explict copy from
sg.copy_from(g)
sg.copy_from_parent()
assert len(sg.get_n_repr()) == 1
assert len(sg.get_e_repr()) == 1
sh = sg.get_n_repr()['h']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment