"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "363699044e365ef977a7646b500402fa585e1b6b"
Commit 9827e481 authored by Mufei Li's avatar Mufei Li Committed by Gan Quan
Browse files

Support for adding nodes/edges after setting representations (#114)

* Fix

1. Fix two typos in gcn.py and gcn_spmv.py
2. Update README

* Fix GCN module

1. Update the outdated graph convolution layer class
2. Fix a bug in the code where dropout never works. Modules like dropout/batch norm depend on whether we are in the training stage or inference stage.

* Fix a bug in dropout

1. dropout depends on nn.Module.training

* Update GCN module

* Fix README

* Fix dropout & remove self.msg_field

* Fix

* Align with TF implementation

* Make g an argument for forward

* Remove features from the argument of GraphConv layer

* Support for create nodes/edges after setting representations

* Remove redundant commit

* Delete test_init_repr.py

* Test case for dynamic addition

* Base 'add_rows' upon 'append'

* Move test function

* Fix

* test by assertion

* changed add_rows to adding blank rows only; adding convert_to to backend

* moving test to basics

* oops mxnet
parent 2758c249
...@@ -120,6 +120,12 @@ def get_context(x): ...@@ -120,6 +120,12 @@ def get_context(x):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[x.context.device_type], x.context.device_id) TVMContext.STR2MASK[x.context.device_type], x.context.device_id)
def convert_to(src, dst):
'''
Convert src to the same dtype and context as dst
'''
return src.copyto(dst.context).astype(dst.dtype)
def _typestr(arr_dtype): def _typestr(arr_dtype):
return arr_dtype return arr_dtype
......
...@@ -93,6 +93,12 @@ def get_context(arr): ...@@ -93,6 +93,12 @@ def get_context(arr):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index) TVMContext.STR2MASK[arr.device.type], arr.device.index)
def convert_to(src, dst):
'''
Convert src to the same dtype and context as dst
'''
return src.to(dst)
def get_tvmtype(arr): def get_tvmtype(arr):
arr_dtype = arr.dtype arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half): if arr_dtype in (th.float16, th.half):
......
...@@ -58,6 +58,10 @@ class Column(object): ...@@ -58,6 +58,10 @@ class Column(object):
"""The column length.""" """The column length."""
return F.shape(self.data)[0] return F.shape(self.data)[0]
@property
def shape(self):
return self.scheme.shape
def __getitem__(self, idx): def __getitem__(self, idx):
"""Return the feature data given the index. """Return the feature data given the index.
...@@ -112,6 +116,23 @@ class Column(object): ...@@ -112,6 +116,23 @@ class Column(object):
else: else:
self.data = F.scatter_row(self.data, user_idx, feats) self.data = F.scatter_row(self.data, user_idx, feats)
def extend(self, feats, feat_scheme=None):
"""Extend the feature data.
Parameters
----------
feats : Tensor
The new features.
"""
if feat_scheme is None:
feat_scheme = Scheme.infer_scheme(feats)
if feat_scheme != self.scheme:
raise DGLError("Cannot update column of scheme %s using feature of scheme %s."
% (feat_scheme, self.scheme))
feats = F.convert_to(feats, self.data)
self.data = F.pack([self.data, feats])
@staticmethod @staticmethod
def create(data): def create(data):
"""Create a new column using the given data.""" """Create a new column using the given data."""
...@@ -156,6 +177,13 @@ class Frame(MutableMapping): ...@@ -156,6 +177,13 @@ class Frame(MutableMapping):
# in the first call and zero initializer will be used later. # in the first call and zero initializer will be used later.
self._initializer = None self._initializer = None
def _warn_and_set_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
# TODO(minjie): handle data type
self._initializer = lambda shape, dtype: F.zeros(shape)
def set_initializer(self, initializer): def set_initializer(self, initializer):
"""Set the initializer for empty values. """Set the initializer for empty values.
...@@ -253,11 +281,7 @@ class Frame(MutableMapping): ...@@ -253,11 +281,7 @@ class Frame(MutableMapping):
' number of rows is unknown. Make sure there is at least' ' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.' % name) ' one column in the frame so number of rows can be inferred.' % name)
if self.initializer is None: if self.initializer is None:
dgl_warning('Initializer is not set. Use zero initializer instead.' self._warn_and_set_initializer()
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
# TODO(minjie): handle data type
self.set_initializer(lambda shape, dtype : F.zeros(shape))
# TODO(minjie): directly init data on the targer device. # TODO(minjie): directly init data on the targer device.
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype) init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
init_data = F.to_context(init_data, ctx) init_data = F.to_context(init_data, ctx)
...@@ -301,13 +325,7 @@ class Frame(MutableMapping): ...@@ -301,13 +325,7 @@ class Frame(MutableMapping):
self._num_rows = other.num_rows self._num_rows = other.num_rows
else: else:
for key, col in other.items(): for key, col in other.items():
sch = self._columns[key].scheme self._columns[key].extend(col.data, col.scheme)
other_sch = col.scheme
if sch != other_sch:
raise DGLError("Cannot append column of scheme %s to column of scheme %s."
% (other_scheme, sch))
self._columns[key].data = F.pack(
[self._columns[key].data, col.data])
self._num_rows += other.num_rows self._num_rows += other.num_rows
def clear(self): def clear(self):
...@@ -553,6 +571,30 @@ class FrameRef(MutableMapping): ...@@ -553,6 +571,30 @@ class FrameRef(MutableMapping):
fcol = self._frame[name] fcol = self._frame[name]
fcol.update(self.index(), data, inplace) fcol.update(self.index(), data, inplace)
def add_rows(self, num_rows):
"""Add blank rows.
For existing fields, the rows will be extended according to their
initializers.
Parameters
----------
num_rows : int
Number of rows to add
"""
feat_placeholders = {}
for key in self._frame:
scheme = self._frame[key].scheme
if self._frame.initializer is None:
self._frame._warn_and_set_initializer()
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype)
feat_placeholders[key] = new_data
self.append(feat_placeholders)
def update_rows(self, query, data, inplace): def update_rows(self, query, data, inplace):
"""Update the rows. """Update the rows.
......
...@@ -75,6 +75,10 @@ class DGLGraph(object): ...@@ -75,6 +75,10 @@ class DGLGraph(object):
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing
if self._node_frame.num_columns > 0 and self._node_frame.num_rows > 0:
self._node_frame.add_rows(num)
def add_edge(self, u, v, reprs=None): def add_edge(self, u, v, reprs=None):
"""Add one edge. """Add one edge.
...@@ -91,6 +95,10 @@ class DGLGraph(object): ...@@ -91,6 +95,10 @@ class DGLGraph(object):
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0:
self._edge_frame.add_rows(1)
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, reprs=None):
"""Add many edges. """Add many edges.
...@@ -109,6 +117,10 @@ class DGLGraph(object): ...@@ -109,6 +117,10 @@ class DGLGraph(object):
#TODO(minjie): change frames #TODO(minjie): change frames
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0:
self._edge_frame.add_rows(len(u))
def clear(self): def clear(self):
"""Clear the graph and its storage.""" """Clear the graph and its storage."""
self._graph.clear() self._graph.clear()
......
...@@ -345,6 +345,38 @@ def test_send_multigraph(): ...@@ -345,6 +345,38 @@ def test_send_multigraph():
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3])) assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5)) assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
def test_dynamic_addition():
N = 3
D = 1
g = DGLGraph()
# Test node addition
g.add_nodes(N)
g.set_n_repr({'h1': th.randn(N, D),
'h2': th.randn(N, D)})
g.add_nodes(3)
n_repr = g.get_n_repr()
assert n_repr['h1'].shape[0] == n_repr['h2'].shape[0] == N + 3
# Test edge addition
g.add_edge(0, 1)
g.add_edge(1, 0)
g.set_e_repr({'h1': th.randn(2, D),
'h2': th.randn(2, D)})
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 2
g.add_edges([0, 2], [2, 0])
e_repr = g.get_e_repr()
g.set_e_repr({'h1': th.randn(4, D)})
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 4
g.add_edge(1, 2)
g.set_e_repr_by_id({'h1': th.randn(1, D)}, eid=4)
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 5
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
...@@ -355,3 +387,4 @@ if __name__ == '__main__': ...@@ -355,3 +387,4 @@ if __name__ == '__main__':
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
test_send_multigraph() test_send_multigraph()
test_dynamic_addition()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment