Commit b0e02e5b authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

Update the subgraph (#73)

* update subgraph.

* update subgraph API.

* keep node embedding.
parent 2883eda6
...@@ -141,7 +141,7 @@ class FrameRef(MutableMapping): ...@@ -141,7 +141,7 @@ class FrameRef(MutableMapping):
else: else:
self.update_rows(key, val) self.update_rows(key, val)
def add_column(self, name, col): def add_column(self, name, col, inplace=False):
shp = F.shape(col) shp = F.shape(col)
if self.is_span_whole_column(): if self.is_span_whole_column():
if self.num_columns == 0: if self.num_columns == 0:
...@@ -157,18 +157,25 @@ class FrameRef(MutableMapping): ...@@ -157,18 +157,25 @@ class FrameRef(MutableMapping):
fcol = F.zeros((self._frame.num_rows,) + shp[1:]) fcol = F.zeros((self._frame.num_rows,) + shp[1:])
fcol = F.to_context(fcol, colctx) fcol = F.to_context(fcol, colctx)
idx = self.index().tousertensor(colctx) idx = self.index().tousertensor(colctx)
newfcol = F.scatter_row(fcol, idx, col) if inplace:
self._frame[name] = newfcol self._frame[name] = fcol
self._frame[name][idx] = col
else:
newfcol = F.scatter_row(fcol, idx, col)
self._frame[name] = newfcol
def update_rows(self, query, other): def update_rows(self, query, other, inplace=False):
rowids = self._getrowid(query) rowids = self._getrowid(query)
for key, col in other.items(): for key, col in other.items():
if key not in self: if key not in self:
# add new column # add new column
tmpref = FrameRef(self._frame, rowids) tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col) tmpref.add_column(key, col, inplace)
idx = rowids.tousertensor(F.get_context(self._frame[key])) idx = rowids.tousertensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col) if inplace:
self._frame[key][idx] = col
else:
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
def __delitem__(self, key): def __delitem__(self, key):
if isinstance(key, str): if isinstance(key, str):
......
...@@ -486,7 +486,7 @@ class DGLGraph(object): ...@@ -486,7 +486,7 @@ class DGLGraph(object):
""" """
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or To set multiple node representations at once, pass `u` with a tensor or
...@@ -524,9 +524,9 @@ class DGLGraph(object): ...@@ -524,9 +524,9 @@ class DGLGraph(object):
self._node_frame[__REPR__] = hu self._node_frame[__REPR__] = hu
else: else:
if utils.is_dict_like(hu): if utils.is_dict_like(hu):
self._node_frame[u] = hu self._node_frame.update_rows(u, hu, inplace=inplace)
else: else:
self._node_frame[u] = {__REPR__ : hu} self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
......
...@@ -15,28 +15,21 @@ class DGLSubGraph(DGLGraph): ...@@ -15,28 +15,21 @@ class DGLSubGraph(DGLGraph):
nodes): nodes):
super(DGLSubGraph, self).__init__() super(DGLSubGraph, self).__init__()
# relabel nodes # relabel nodes
self._node_mapping = utils.build_relabel_dict(nodes) self._parent = parent
self._parent_nid = utils.toindex(nodes) self._parent_nid = utils.toindex(nodes)
eids = [] self._graph, self._parent_eid = parent._graph.node_subgraph(self._parent_nid)
# create subgraph self.reset_messages()
for eid, (u, v) in enumerate(parent.edge_list):
if u in self._node_mapping and v in self._node_mapping:
self.add_edge(self._node_mapping[u],
self._node_mapping[v])
eids.append(eid)
self._parent_eid = utils.toindex(eids)
def copy_from(self, parent): def copy_to_parent(self, inplace=False):
self._parent._node_frame.update_rows(self._parent_nid, self._node_frame, inplace=inplace)
self._parent._edge_frame.update_rows(self._parent_eid, self._edge_frame, inplace=inplace)
def copy_from_parent(self):
"""Copy node/edge features from the parent graph. """Copy node/edge features from the parent graph.
All old features will be removed. All old features will be removed.
Parameters
----------
parent : DGLGraph
The parent graph to copy from.
""" """
if parent._node_frame.num_rows != 0: if self._parent._node_frame.num_rows != 0:
self._node_frame = FrameRef(Frame(parent._node_frame[self._parent_nid])) self._node_frame = FrameRef(Frame(self._parent._node_frame[self._parent_nid]))
if parent._edge_frame.num_rows != 0: if self._parent._edge_frame.num_rows != 0:
self._edge_frame = FrameRef(Frame(parent._edge_frame[self._parent_eid])) self._edge_frame = FrameRef(Frame(self._parent._edge_frame[self._parent_eid]))
...@@ -35,7 +35,7 @@ def test_basics(): ...@@ -35,7 +35,7 @@ def test_basics():
assert len(sg.get_n_repr()) == 0 assert len(sg.get_n_repr()) == 0
assert len(sg.get_e_repr()) == 0 assert len(sg.get_e_repr()) == 0
# the data is copied after explict copy from # the data is copied after explict copy from
sg.copy_from(g) sg.copy_from_parent()
assert len(sg.get_n_repr()) == 1 assert len(sg.get_n_repr()) == 1
assert len(sg.get_e_repr()) == 1 assert len(sg.get_e_repr()) == 1
sh = sg.get_n_repr()['h'] sh = sg.get_n_repr()['h']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment