Unverified Commit 0ec1a492 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Frame] change num rows behavior (#163)

* move initializer to column scheme; change num rows behavior

* poke mx ci

* fix mx utest in append

* fix bug in add edges

* utest for updating partial rows

* fix bug in from_networkx and from_scipy_matrix

* revert per-col initializer change

* fix pickle utest
parent 440aecee
...@@ -128,8 +128,8 @@ def unbatch(graph): ...@@ -128,8 +128,8 @@ def unbatch(graph):
be = graph.batch_num_edges be = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn)) pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn))
# split the frames # split the frames
node_frames = [FrameRef() for i in range(bsize)] node_frames = [FrameRef(Frame(num_rows=n)) for n in bn]
edge_frames = [FrameRef() for i in range(bsize)] edge_frames = [FrameRef(Frame(num_rows=n)) for n in be]
for attr, col in graph._node_frame.items(): for attr, col in graph._node_frame.items():
col_splits = F.split(col, bn, dim=0) col_splits = F.split(col, bn, dim=0)
for i in range(bsize): for i in range(bsize):
......
...@@ -18,7 +18,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -18,7 +18,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
---------- ----------
shape : tuple of int shape : tuple of int
The feature shape. The feature shape.
dtype : TVMType dtype : backend-specific type object
The feature data type. The feature data type.
""" """
# FIXME: # FIXME:
...@@ -29,8 +29,9 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -29,8 +29,9 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
# https://github.com/pytorch/pytorch/issues/14057 # https://github.com/pytorch/pytorch/issues/14057
if sys.version_info.major == 3 and sys.version_info.minor == 5: if sys.version_info.major == 3 and sys.version_info.minor == 5:
def __reduce__(self): def __reduce__(self):
return self._reconstruct_scheme, \ state = (self.shape, F.reverse_data_type_dict[self.dtype])
(self.shape, F.reverse_data_type_dict[self.dtype]) return self._reconstruct_scheme, state
@classmethod @classmethod
def _reconstruct_scheme(cls, shape, dtype_str): def _reconstruct_scheme(cls, shape, dtype_str):
...@@ -155,7 +156,7 @@ class Column(object): ...@@ -155,7 +156,7 @@ class Column(object):
def create(data): def create(data):
"""Create a new column using the given data.""" """Create a new column using the given data."""
if isinstance(data, Column): if isinstance(data, Column):
return Column(data.data) return Column(data.data, data.scheme)
else: else:
return Column(data) return Column(data)
...@@ -177,11 +178,14 @@ class Frame(MutableMapping): ...@@ -177,11 +178,14 @@ class Frame(MutableMapping):
this frame will NOT share columns with the given frame. So any out-of-place this frame will NOT share columns with the given frame. So any out-of-place
update on one will not be reflected in the other. In-place updates will update on one will not be reflected in the other. In-place updates will
be seen by both. This follows the semantics of Python's containers. be seen by both. This follows the semantics of Python's containers.
num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows``
will be ignored and inferred from the given data.
""" """
def __init__(self, data=None): def __init__(self, data=None, num_rows=0):
if data is None: if data is None:
self._columns = dict() self._columns = dict()
self._num_rows = 0 self._num_rows = num_rows
else: else:
# Note that we always create a new column for the given data. # Note that we always create a new column for the given data.
# This avoids two frames accidentally sharing the same column. # This avoids two frames accidentally sharing the same column.
...@@ -198,7 +202,7 @@ class Frame(MutableMapping): ...@@ -198,7 +202,7 @@ class Frame(MutableMapping):
# Initializer for empty values. Initializer is a callable. # Initializer for empty values. Initializer is a callable.
# If it is None, a warning will be raised # If it is None, a warning will be raised
# in the first call and zero initializer will be used later. # in the first call and zero initializer will be used later.
self._initializers = {} self._initializers = {} # per-column initializers
self._default_initializer = None self._default_initializer = None
def _warn_and_set_initializer(self): def _warn_and_set_initializer(self):
...@@ -295,8 +299,6 @@ class Frame(MutableMapping): ...@@ -295,8 +299,6 @@ class Frame(MutableMapping):
The column name. The column name.
""" """
del self._columns[name] del self._columns[name]
if len(self._columns) == 0:
self._num_rows = 0
def add_column(self, name, scheme, ctx): def add_column(self, name, scheme, ctx):
"""Add a new column to the frame. """Add a new column to the frame.
...@@ -315,10 +317,6 @@ class Frame(MutableMapping): ...@@ -315,10 +317,6 @@ class Frame(MutableMapping):
if name in self: if name in self:
dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name) dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
return return
if self.num_rows == 0:
raise DGLError('Cannot add column "%s" using column schemes because'
' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.' % name)
if self.get_initializer(name) is None: if self.get_initializer(name) is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
init_data = self.get_initializer(name)( init_data = self.get_initializer(name)(
...@@ -361,19 +359,22 @@ class Frame(MutableMapping): ...@@ -361,19 +359,22 @@ class Frame(MutableMapping):
The column data. The column data.
""" """
col = Column.create(data) col = Column.create(data)
if self.num_columns == 0: if len(col) != self.num_rows:
self._num_rows = len(col)
elif len(col) != self._num_rows:
raise DGLError('Expected data to have %d rows, got %d.' % raise DGLError('Expected data to have %d rows, got %d.' %
(self._num_rows, len(col))) (self.num_rows, len(col)))
self._columns[name] = col self._columns[name] = col
def _append(self, other): def _append(self, other):
# NOTE: `other` can be empty. # NOTE: `other` can be empty.
if len(self._columns) == 0: if self.num_rows == 0:
self._columns = {key: col for key, col in other.items()} # if no rows in current frame; append is equivalent to
# directly updating columns.
self._columns = {key: Column.create(data) for key, data in other.items()}
else: else:
for key, col in other.items(): for key, col in other.items():
if key not in self._columns:
# the column does not exist; init a new column
self.add_column(key, col.scheme, F.context(col.data))
self._columns[key].extend(col.data, col.scheme) self._columns[key].extend(col.data, col.scheme)
def append(self, other): def append(self, other):
...@@ -390,7 +391,6 @@ class Frame(MutableMapping): ...@@ -390,7 +391,6 @@ class Frame(MutableMapping):
""" """
if not isinstance(other, Frame): if not isinstance(other, Frame):
other = Frame(other) other = Frame(other)
self._append(other) self._append(other)
self._num_rows += other.num_rows self._num_rows += other.num_rows
...@@ -711,7 +711,7 @@ class FrameRef(MutableMapping): ...@@ -711,7 +711,7 @@ class FrameRef(MutableMapping):
Please note that "deleted" rows are not really deleted, but simply removed Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not reflect on the other. By contrast, deleting columns is real. from one ref will not reflect on the other. However, deleting columns is real.
Parameters Parameters
---------- ----------
...@@ -720,8 +720,6 @@ class FrameRef(MutableMapping): ...@@ -720,8 +720,6 @@ class FrameRef(MutableMapping):
""" """
if isinstance(key, str): if isinstance(key, str):
del self._frame[key] del self._frame[key]
if len(self._frame) == 0:
self.clear()
else: else:
self.delete_rows(key) self.delete_rows(key)
......
...@@ -141,12 +141,17 @@ class DGLGraph(object): ...@@ -141,12 +141,17 @@ class DGLGraph(object):
self._readonly=readonly self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly) self._graph = create_graph_index(graph_data, multigraph, readonly)
# frame # frame
self._node_frame = node_frame if node_frame is not None else FrameRef() if node_frame is None:
self._edge_frame = edge_frame if edge_frame is not None else FrameRef() self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
self._node_frame = node_frame
if edge_frame is None:
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# msg graph & frame # msg graph & frame
self._msg_graph = create_graph_index(multigraph=multigraph) self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
self._msg_edges = []
self.reset_messages() self.reset_messages()
# registered functions # registered functions
self._message_func = None self._message_func = None
...@@ -154,25 +159,25 @@ class DGLGraph(object): ...@@ -154,25 +159,25 @@ class DGLGraph(object):
self._apply_node_func = None self._apply_node_func = None
self._apply_edge_func = None self._apply_edge_func = None
def add_nodes(self, num, reprs=None): def add_nodes(self, num, data=None):
"""Add nodes. """Add nodes.
Parameters Parameters
---------- ----------
num : int num : int
Number of nodes to be added. Number of nodes to be added.
reprs : dict data : dict
Optional node representations. Optional node feature data.
""" """
self._graph.add_nodes(num) self._graph.add_nodes(num)
self._msg_graph.add_nodes(num) self._msg_graph.add_nodes(num)
#TODO(minjie): change frames if data is None:
assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num) self._node_frame.add_rows(num)
else:
self._node_frame.append(data)
def add_edge(self, u, v, reprs=None): def add_edge(self, u, v, data=None):
"""Add one edge. """Add one edge.
Parameters Parameters
...@@ -181,21 +186,21 @@ class DGLGraph(object): ...@@ -181,21 +186,21 @@ class DGLGraph(object):
The src node. The src node.
v : int v : int
The dst node. The dst node.
reprs : dict data : dict
Optional edge representation. Optional edge feature data.
See Also See Also
-------- --------
add_edges add_edges
""" """
self._graph.add_edge(u, v) self._graph.add_edge(u, v)
#TODO(minjie): change frames if data is None:
assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(1) self._edge_frame.add_rows(1)
else:
self._edge_frame.append(data)
def add_edges(self, u, v, reprs=None): def add_edges(self, u, v, data=None):
"""Add many edges. """Add many edges.
Parameters Parameters
...@@ -214,11 +219,12 @@ class DGLGraph(object): ...@@ -214,11 +219,12 @@ class DGLGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
self._graph.add_edges(u, v) self._graph.add_edges(u, v)
#TODO(minjie): change frames if data is None:
assert reprs is None
# Initialize feature placeholders if there are features existing # Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(len(u)) # NOTE: use max due to edge broadcasting syntax
self._edge_frame.add_rows(max(len(u), len(v)))
else:
self._edge_frame.append(data)
def clear(self): def clear(self):
"""Clear the graph and its storage.""" """Clear the graph and its storage."""
...@@ -227,13 +233,11 @@ class DGLGraph(object): ...@@ -227,13 +233,11 @@ class DGLGraph(object):
self._edge_frame.clear() self._edge_frame.clear()
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
self._msg_edges.clear()
def reset_messages(self): def reset_messages(self):
"""Clear all messages.""" """Clear all messages."""
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
self._msg_edges.clear()
self._msg_graph.add_nodes(self.number_of_nodes()) self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self): def number_of_nodes(self):
...@@ -672,6 +676,8 @@ class DGLGraph(object): ...@@ -672,6 +676,8 @@ class DGLGraph(object):
""" """
self.clear() self.clear()
self._graph.from_networkx(nx_graph) self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes()) self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes # copy attributes
def _batcher(lst): def _batcher(lst):
...@@ -705,6 +711,8 @@ class DGLGraph(object): ...@@ -705,6 +711,8 @@ class DGLGraph(object):
""" """
self.clear() self.clear()
self._graph.from_scipy_sparse_matrix(a) self._graph.from_scipy_sparse_matrix(a)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes()) self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self): def node_attr_schemes(self):
......
...@@ -23,7 +23,7 @@ def create_test_data(grad=False): ...@@ -23,7 +23,7 @@ def create_test_data(grad=False):
def test_create(): def test_create():
data = create_test_data() data = create_test_data()
f1 = Frame() f1 = Frame(num_rows=N)
for k, v in data.items(): for k, v in data.items():
f1.update_column(k, v) f1.update_column(k, v)
print(f1.schemes) print(f1.schemes)
...@@ -56,12 +56,7 @@ def test_column1(): ...@@ -56,12 +56,7 @@ def test_column1():
del f['a2'] del f['a2']
assert len(f) == 1 assert len(f) == 1
del f['a3'] del f['a3']
assert f.num_rows == 0
assert len(f) == 0 assert len(f) == 0
# add a different length column should succeed
f['a4'] = th.zeros([N+1, D])
assert f.num_rows == N+1
assert len(f) == 1
def test_column2(): def test_column2():
# Test frameref column getter/setter # Test frameref column getter/setter
...@@ -119,6 +114,23 @@ def test_append2(): ...@@ -119,6 +114,23 @@ def test_append2():
assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64)) assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_append3():
# test append on empty frame
f = Frame(num_rows=5)
data = {'h' : th.ones((3, 2))}
f.append(data)
assert f.num_rows == 8
ans = th.cat([th.zeros((5, 2)), th.ones((3, 2))], dim=0)
assert U.allclose(f['h'].data, ans)
# test append with new column
data = {'h' : 2 * th.ones((3, 2)), 'w' : 2 * th.ones((3, 2))}
f.append(data)
assert f.num_rows == 11
ans1 = th.cat([ans, 2 * th.ones((3, 2))], 0)
ans2 = th.cat([th.zeros((8, 2)), 2 * th.ones((3, 2))], 0)
assert U.allclose(f['h'].data, ans1)
assert U.allclose(f['w'].data, ans2)
def test_row1(): def test_row1():
# test row getter/setter # test row getter/setter
data = create_test_data() data = create_test_data()
...@@ -210,6 +222,15 @@ def test_row3(): ...@@ -210,6 +222,15 @@ def test_row3():
for k, v in f.items(): for k, v in f.items():
assert U.allclose(v, data[k][newidx]) assert U.allclose(v, data[k][newidx])
def test_row4():
# test updating row with empty frame but has preset num_rows
f = FrameRef(Frame(num_rows=5))
rowid = Index(th.tensor([0, 2, 4]))
f[rowid] = {'h' : th.ones((3, 2))}
ans = th.zeros((5, 2))
ans[th.tensor([0, 2, 4])] = th.ones((3, 2))
assert U.allclose(f['h'], ans)
def test_sharing(): def test_sharing():
data = Frame(create_test_data()) data = Frame(create_test_data())
f1 = FrameRef(data, index=[0, 1, 2, 3]) f1 = FrameRef(data, index=[0, 1, 2, 3])
...@@ -290,9 +311,11 @@ if __name__ == '__main__': ...@@ -290,9 +311,11 @@ if __name__ == '__main__':
test_column2() test_column2()
test_append1() test_append1()
test_append2() test_append2()
test_append3()
test_row1() test_row1()
test_row2() test_row2()
test_row3() test_row3()
test_row4()
test_sharing() test_sharing()
test_slicing() test_slicing()
test_add_rows() test_add_rows()
...@@ -4,6 +4,27 @@ import numpy as np ...@@ -4,6 +4,27 @@ import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
import torch as th import torch as th
import dgl import dgl
import utils as U
def test_graph_creation():
g = dgl.DGLGraph()
# test add nodes with data
g.add_nodes(5)
g.add_nodes(5, {'h' : th.ones((5, 2))})
ans = th.cat([th.zeros(5, 2), th.ones(5, 2)], 0)
U.allclose(ans, g.ndata['h'])
g.ndata['w'] = 2 * th.ones((10, 2))
assert U.allclose(2 * th.ones((10, 2)), g.ndata['w'])
# test add edges with data
g.add_edges([2, 3], [3, 4])
g.add_edges([0, 1], [1, 2], {'m' : th.ones((2, 2))})
ans = th.cat([th.zeros(2, 2), th.ones(2, 2)], 0)
assert U.allclose(ans, g.edata['m'])
# test clear and add again
g.clear()
g.add_nodes(5)
g.ndata['h'] = 3 * th.ones((5, 2))
assert U.allclose(3 * th.ones((5, 2)), g.ndata['h'])
def test_adjmat_speed(): def test_adjmat_speed():
n = 1000 n = 1000
...@@ -36,5 +57,6 @@ def test_incmat_speed(): ...@@ -36,5 +57,6 @@ def test_incmat_speed():
assert dur2 < dur1 assert dur2 < dur1
if __name__ == '__main__': if __name__ == '__main__':
test_graph_creation()
test_adjmat_speed() test_adjmat_speed()
test_incmat_speed() test_incmat_speed()
...@@ -7,8 +7,8 @@ Capsule Network Tutorial ...@@ -7,8 +7,8 @@ Capsule Network Tutorial
**Author**: Jinjing Zhou, `Jake **Author**: Jinjing Zhou, `Jake
Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
It is perhaps a little surprising that some of the more classical models can also be described in terms of graphs, It is perhaps a little surprising that some of the more classical models can
offering a different perspective. also be described in terms of graphs, offering a different perspective.
This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__. This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__.
""" """
####################################################################################### #######################################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment