Commit e515f531 authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

[Frame] supporting partial row on first column and individual column initializers (#160)

addresses issue #156
parent cab1fdf2
......@@ -160,7 +160,7 @@ class Column(object):
return Column(data)
def _default_zero_initializer(shape, dtype, ctx):
def zero_initializer(shape, dtype, ctx):
return F.zeros(shape, dtype, ctx)
......@@ -198,16 +198,33 @@ class Frame(MutableMapping):
# Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised
# in the first call and zero initializer will be used later.
self._initializer = None
self._initializers = {}
self._default_initializer = None
def _warn_and_set_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
self._initializer = _default_zero_initializer
self._default_initializer = zero_initializer
def set_initializer(self, initializer):
"""Set the initializer for empty values.
def get_initializer(self, column=None):
"""Get the initializer for empty values for the given column.
Parameters
----------
column : str
The column
Returns
-------
callable
The initializer
"""
return self._initializers.get(column, self._default_initializer)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values, for a given column or all future
columns.
Initializer is a callable that returns a tensor given the shape and data type.
......@@ -215,13 +232,13 @@ class Frame(MutableMapping):
----------
initializer : callable
The initializer.
column : str, optional
The column name
"""
self._initializer = initializer
@property
def initializer(self):
"""Return the initializer of this frame."""
return self._initializer
if column is None:
self._default_initializer = initializer
else:
self._initializers[column] = initializer
@property
def schemes(self):
......@@ -302,11 +319,37 @@ class Frame(MutableMapping):
raise DGLError('Cannot add column "%s" using column schemes because'
' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.' % name)
if self.initializer is None:
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
init_data = self.get_initializer(name)(
(self.num_rows,) + scheme.shape, scheme.dtype, ctx)
self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows):
"""Add blank rows to this frame.
For existing fields, the rows will be extended according to their
initializers.
Parameters
----------
num_rows : int
The number of new rows
"""
self._num_rows += num_rows
feat_placeholders = {}
for key, col in self._columns.items():
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders))
def update_column(self, name, data):
"""Add or replace the column with the given name and data.
......@@ -325,6 +368,14 @@ class Frame(MutableMapping):
(self._num_rows, len(col)))
self._columns[name] = col
def _append(self, other):
# NOTE: `other` can be empty.
if len(self._columns) == 0:
self._columns = {key: col for key, col in other.items()}
else:
for key, col in other.items():
self._columns[key].extend(col.data, col.scheme)
def append(self, other):
"""Append another frame's data into this frame.
......@@ -339,13 +390,8 @@ class Frame(MutableMapping):
"""
if not isinstance(other, Frame):
other = Frame(other)
if len(self._columns) == 0:
for key, col in other.items():
self._columns[key] = col
self._num_rows = other.num_rows
else:
for key, col in other.items():
self._columns[key].extend(col.data, col.scheme)
self._append(other)
self._num_rows += other.num_rows
def clear(self):
......@@ -416,7 +462,7 @@ class FrameRef(MutableMapping):
else:
return len(self._index_data)
def set_initializer(self, initializer):
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type.
......@@ -425,8 +471,10 @@ class FrameRef(MutableMapping):
----------
initializer : callable
The initializer.
column : str, optional
The column name
"""
self._frame.set_initializer(initializer)
self._frame.set_initializer(initializer, column=column)
def index(self):
"""Return the index object.
......@@ -605,28 +653,27 @@ class FrameRef(MutableMapping):
fcol.update(self.index_or_slice(), data, inplace)
def add_rows(self, num_rows):
"""Add blank rows.
"""Add blank rows to the underlying frame.
For existing fields, the rows will be extended according to their
initializers.
Note: only available for FrameRef that spans the whole column. The row
span will extend to new rows. Other FrameRefs referencing the same
frame will not be affected.
Parameters
----------
num_rows : int
Number of rows to add
"""
feat_placeholders = {}
for key in self._frame:
scheme = self._frame[key].scheme
ctx = F.context(self._frame[key].data)
if self._frame.initializer is None:
self._frame._warn_and_set_initializer()
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data
self.append(feat_placeholders)
if not self.is_span_whole_column():
raise RuntimeError('FrameRef not spanning whole column.')
self._frame.add_rows(num_rows)
if self.is_contiguous():
self._index_data = slice(0, self._index_data.stop + num_rows)
else:
self._index_data.extend(range(self.num_rows, self.num_rows + num_rows))
def update_rows(self, query, data, inplace):
"""Update the rows.
......
......@@ -170,7 +170,6 @@ class DGLGraph(object):
assert reprs is None
# Initialize feature placeholders if there are features existing
if self._node_frame.num_columns > 0 and self._node_frame.num_rows > 0:
self._node_frame.add_rows(num)
def add_edge(self, u, v, reprs=None):
......@@ -194,7 +193,6 @@ class DGLGraph(object):
assert reprs is None
# Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0:
self._edge_frame.add_rows(1)
def add_edges(self, u, v, reprs=None):
......@@ -220,7 +218,6 @@ class DGLGraph(object):
assert reprs is None
# Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0:
self._edge_frame.add_rows(len(u))
def clear(self):
......
......@@ -269,6 +269,21 @@ def test_slicing():
f2_a1[0:2] = 0
assert U.allclose(f2['a1'], f2_a1)
def test_add_rows():
data = Frame()
f1 = FrameRef(data)
f1.add_rows(4)
x = th.randn(1, 4)
f1[Index(th.tensor([0]))] = {'x': x}
ans = th.cat([x, th.zeros(3, 4)])
assert U.allclose(f1['x'], ans)
f1.add_rows(4)
f1[4:8] = {'x': th.ones(4, 4), 'y': th.ones(4, 5)}
ans = th.cat([ans, th.ones(4, 4)])
assert U.allclose(f1['x'], ans)
ans = th.cat([th.zeros(4, 5), th.ones(4, 5)])
assert U.allclose(f1['y'], ans)
if __name__ == '__main__':
test_create()
test_column1()
......@@ -280,3 +295,4 @@ if __name__ == '__main__':
test_row3()
test_sharing()
test_slicing()
test_add_rows()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment