Commit e515f531 authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

[Frame] supporting partial row on first column and individual column initializers (#160)

addresses issue #156
parent cab1fdf2
...@@ -160,7 +160,7 @@ class Column(object): ...@@ -160,7 +160,7 @@ class Column(object):
return Column(data) return Column(data)
def _default_zero_initializer(shape, dtype, ctx): def zero_initializer(shape, dtype, ctx):
return F.zeros(shape, dtype, ctx) return F.zeros(shape, dtype, ctx)
...@@ -198,16 +198,33 @@ class Frame(MutableMapping): ...@@ -198,16 +198,33 @@ class Frame(MutableMapping):
# Initializer for empty values. Initializer is a callable. # Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised # If is none, then a warning will be raised
# in the first call and zero initializer will be used later. # in the first call and zero initializer will be used later.
self._initializer = None self._initializers = {}
self._default_initializer = None
def _warn_and_set_initializer(self): def _warn_and_set_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.' dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to' ' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.') ' explicitly specify which initializer to use.')
self._initializer = _default_zero_initializer self._default_initializer = zero_initializer
def set_initializer(self, initializer): def get_initializer(self, column=None):
"""Set the initializer for empty values. """Get the initializer for empty values for the given column.
Parameters
----------
column : str
The column
Returns
-------
callable
The initializer
"""
return self._initializers.get(column, self._default_initializer)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values, for a given column or all future
columns.
Initializer is a callable that returns a tensor given the shape and data type. Initializer is a callable that returns a tensor given the shape and data type.
...@@ -215,13 +232,13 @@ class Frame(MutableMapping): ...@@ -215,13 +232,13 @@ class Frame(MutableMapping):
---------- ----------
initializer : callable initializer : callable
The initializer. The initializer.
column : str, optional
The column name
""" """
self._initializer = initializer if column is None:
self._default_initializer = initializer
@property else:
def initializer(self): self._initializers[column] = initializer
"""Return the initializer of this frame."""
return self._initializer
@property @property
def schemes(self): def schemes(self):
...@@ -302,11 +319,37 @@ class Frame(MutableMapping): ...@@ -302,11 +319,37 @@ class Frame(MutableMapping):
raise DGLError('Cannot add column "%s" using column schemes because' raise DGLError('Cannot add column "%s" using column schemes because'
' number of rows is unknown. Make sure there is at least' ' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.' % name) ' one column in the frame so number of rows can be inferred.' % name)
if self.initializer is None: if self.get_initializer(name) is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx) init_data = self.get_initializer(name)(
(self.num_rows,) + scheme.shape, scheme.dtype, ctx)
self._columns[name] = Column(init_data, scheme) self._columns[name] = Column(init_data, scheme)
def add_rows(self, num_rows):
"""Add blank rows to this frame.
For existing fields, the rows will be extended according to their
initializers.
Parameters
----------
num_rows : int
The number of new rows
"""
self._num_rows += num_rows
feat_placeholders = {}
for key, col in self._columns.items():
scheme = col.scheme
ctx = F.context(col.data)
if self.get_initializer(key) is None:
self._warn_and_set_initializer()
new_data = self.get_initializer(key)(
(num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data
self._append(Frame(feat_placeholders))
def update_column(self, name, data): def update_column(self, name, data):
"""Add or replace the column with the given name and data. """Add or replace the column with the given name and data.
...@@ -325,6 +368,14 @@ class Frame(MutableMapping): ...@@ -325,6 +368,14 @@ class Frame(MutableMapping):
(self._num_rows, len(col))) (self._num_rows, len(col)))
self._columns[name] = col self._columns[name] = col
def _append(self, other):
# NOTE: `other` can be empty.
if len(self._columns) == 0:
self._columns = {key: col for key, col in other.items()}
else:
for key, col in other.items():
self._columns[key].extend(col.data, col.scheme)
def append(self, other): def append(self, other):
"""Append another frame's data into this frame. """Append another frame's data into this frame.
...@@ -339,14 +390,9 @@ class Frame(MutableMapping): ...@@ -339,14 +390,9 @@ class Frame(MutableMapping):
""" """
if not isinstance(other, Frame): if not isinstance(other, Frame):
other = Frame(other) other = Frame(other)
if len(self._columns) == 0:
for key, col in other.items(): self._append(other)
self._columns[key] = col self._num_rows += other.num_rows
self._num_rows = other.num_rows
else:
for key, col in other.items():
self._columns[key].extend(col.data, col.scheme)
self._num_rows += other.num_rows
def clear(self): def clear(self):
"""Clear this frame. Remove all the columns.""" """Clear this frame. Remove all the columns."""
...@@ -416,7 +462,7 @@ class FrameRef(MutableMapping): ...@@ -416,7 +462,7 @@ class FrameRef(MutableMapping):
else: else:
return len(self._index_data) return len(self._index_data)
def set_initializer(self, initializer): def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values. """Set the initializer for empty values.
Initializer is a callable that returns a tensor given the shape and data type. Initializer is a callable that returns a tensor given the shape and data type.
...@@ -425,8 +471,10 @@ class FrameRef(MutableMapping): ...@@ -425,8 +471,10 @@ class FrameRef(MutableMapping):
---------- ----------
initializer : callable initializer : callable
The initializer. The initializer.
column : str, optional
The column name
""" """
self._frame.set_initializer(initializer) self._frame.set_initializer(initializer, column=column)
def index(self): def index(self):
"""Return the index object. """Return the index object.
...@@ -605,28 +653,27 @@ class FrameRef(MutableMapping): ...@@ -605,28 +653,27 @@ class FrameRef(MutableMapping):
fcol.update(self.index_or_slice(), data, inplace) fcol.update(self.index_or_slice(), data, inplace)
def add_rows(self, num_rows): def add_rows(self, num_rows):
"""Add blank rows. """Add blank rows to the underlying frame.
For existing fields, the rows will be extended according to their For existing fields, the rows will be extended according to their
initializers. initializers.
Note: only available for FrameRef that spans the whole column. The row
span will extend to new rows. Other FrameRefs referencing the same
frame will not be affected.
Parameters Parameters
---------- ----------
num_rows : int num_rows : int
Number of rows to add Number of rows to add
""" """
if not self.is_span_whole_column():
feat_placeholders = {} raise RuntimeError('FrameRef not spanning whole column.')
self._frame.add_rows(num_rows)
for key in self._frame: if self.is_contiguous():
scheme = self._frame[key].scheme self._index_data = slice(0, self._index_data.stop + num_rows)
ctx = F.context(self._frame[key].data) else:
if self._frame.initializer is None: self._index_data.extend(range(self.num_rows, self.num_rows + num_rows))
self._frame._warn_and_set_initializer()
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data
self.append(feat_placeholders)
def update_rows(self, query, data, inplace): def update_rows(self, query, data, inplace):
"""Update the rows. """Update the rows.
......
...@@ -170,8 +170,7 @@ class DGLGraph(object): ...@@ -170,8 +170,7 @@ class DGLGraph(object):
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
if self._node_frame.num_columns > 0 and self._node_frame.num_rows > 0: self._node_frame.add_rows(num)
self._node_frame.add_rows(num)
def add_edge(self, u, v, reprs=None): def add_edge(self, u, v, reprs=None):
"""Add one edge. """Add one edge.
...@@ -194,8 +193,7 @@ class DGLGraph(object): ...@@ -194,8 +193,7 @@ class DGLGraph(object):
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0: self._edge_frame.add_rows(1)
self._edge_frame.add_rows(1)
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, reprs=None):
"""Add many edges. """Add many edges.
...@@ -220,8 +218,7 @@ class DGLGraph(object): ...@@ -220,8 +218,7 @@ class DGLGraph(object):
assert reprs is None assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
if self._edge_frame.num_columns > 0 and self._edge_frame.num_rows > 0: self._edge_frame.add_rows(len(u))
self._edge_frame.add_rows(len(u))
def clear(self): def clear(self):
"""Clear the graph and its storage.""" """Clear the graph and its storage."""
......
...@@ -269,6 +269,21 @@ def test_slicing(): ...@@ -269,6 +269,21 @@ def test_slicing():
f2_a1[0:2] = 0 f2_a1[0:2] = 0
assert U.allclose(f2['a1'], f2_a1) assert U.allclose(f2['a1'], f2_a1)
def test_add_rows():
data = Frame()
f1 = FrameRef(data)
f1.add_rows(4)
x = th.randn(1, 4)
f1[Index(th.tensor([0]))] = {'x': x}
ans = th.cat([x, th.zeros(3, 4)])
assert U.allclose(f1['x'], ans)
f1.add_rows(4)
f1[4:8] = {'x': th.ones(4, 4), 'y': th.ones(4, 5)}
ans = th.cat([ans, th.ones(4, 4)])
assert U.allclose(f1['x'], ans)
ans = th.cat([th.zeros(4, 5), th.ones(4, 5)])
assert U.allclose(f1['y'], ans)
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
test_column1() test_column1()
...@@ -280,3 +295,4 @@ if __name__ == '__main__': ...@@ -280,3 +295,4 @@ if __name__ == '__main__':
test_row3() test_row3()
test_sharing() test_sharing()
test_slicing() test_slicing()
test_add_rows()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment