"vscode:/vscode.git/clone" did not exist on "369a51c432c310a74de6e185ea072af4a398ec67"
Unverified Commit 0ec1a492 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Frame] change num rows behavior (#163)

* move initializer to column scheme; change num rows behavior

* poke mx ci

* fix mx utest in append

* fix bug in add edges

* utest for updating partial rows

* fix bug in from_networkx and from_scipy_matrix

* revert per-col initializer change

* fix pickle utest
parent 440aecee
......@@ -128,8 +128,8 @@ def unbatch(graph):
be = graph.batch_num_edges
pttns = gi.disjoint_partition(graph._graph, utils.toindex(bn))
# split the frames
node_frames = [FrameRef() for i in range(bsize)]
edge_frames = [FrameRef() for i in range(bsize)]
node_frames = [FrameRef(Frame(num_rows=n)) for n in bn]
edge_frames = [FrameRef(Frame(num_rows=n)) for n in be]
for attr, col in graph._node_frame.items():
col_splits = F.split(col, bn, dim=0)
for i in range(bsize):
......
......@@ -18,7 +18,7 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
----------
shape : tuple of int
The feature shape.
dtype : TVMType
dtype : backend-specific type object
The feature data type.
"""
# FIXME:
......@@ -29,8 +29,9 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
# https://github.com/pytorch/pytorch/issues/14057
if sys.version_info.major == 3 and sys.version_info.minor == 5:
def __reduce__(self):
return self._reconstruct_scheme, \
(self.shape, F.reverse_data_type_dict[self.dtype])
state = (self.shape, F.reverse_data_type_dict[self.dtype])
return self._reconstruct_scheme, state
@classmethod
def _reconstruct_scheme(cls, shape, dtype_str):
......@@ -155,7 +156,7 @@ class Column(object):
def create(data):
"""Create a new column using the given data."""
if isinstance(data, Column):
return Column(data.data)
return Column(data.data, data.scheme)
else:
return Column(data)
......@@ -177,11 +178,14 @@ class Frame(MutableMapping):
this frame will NOT share columns with the given frame. So any out-place
update on one will not reflect to the other. The inplace update will
be seen by both. This follows the semantic of python's container.
num_rows : int, optional [default=0]
The number of rows in this frame. If ``data`` is provided, ``num_rows``
will be ignored and inferred from the given data.
"""
def __init__(self, data=None):
def __init__(self, data=None, num_rows=0):
if data is None:
self._columns = dict()
self._num_rows = 0
self._num_rows = num_rows
else:
# Note that we always create a new column for the given data.
# This avoids two frames accidentally sharing the same column.
......@@ -198,7 +202,7 @@ class Frame(MutableMapping):
# Initializer for empty values. Initializer is a callable.
# If is none, then a warning will be raised
# in the first call and zero initializer will be used later.
self._initializers = {}
self._initializers = {} # per-column initializers
self._default_initializer = None
def _warn_and_set_initializer(self):
......@@ -220,7 +224,7 @@ class Frame(MutableMapping):
callable
The initializer
"""
return self._initializers.get(column, self._default_initializer)
return self._initializers.get(column, self._default_initializer)
def set_initializer(self, initializer, column=None):
"""Set the initializer for empty values, for a given column or all future
......@@ -295,8 +299,6 @@ class Frame(MutableMapping):
The column name.
"""
del self._columns[name]
if len(self._columns) == 0:
self._num_rows = 0
def add_column(self, name, scheme, ctx):
"""Add a new column to the frame.
......@@ -315,10 +317,6 @@ class Frame(MutableMapping):
if name in self:
dgl_warning('Column "%s" already exists. Ignore adding this column again.' % name)
return
if self.num_rows == 0:
raise DGLError('Cannot add column "%s" using column schemes because'
' number of rows is unknown. Make sure there is at least'
' one column in the frame so number of rows can be inferred.' % name)
if self.get_initializer(name) is None:
self._warn_and_set_initializer()
init_data = self.get_initializer(name)(
......@@ -361,19 +359,22 @@ class Frame(MutableMapping):
The column data.
"""
col = Column.create(data)
if self.num_columns == 0:
self._num_rows = len(col)
elif len(col) != self._num_rows:
if len(col) != self.num_rows:
raise DGLError('Expected data to have %d rows, got %d.' %
(self._num_rows, len(col)))
(self.num_rows, len(col)))
self._columns[name] = col
def _append(self, other):
# NOTE: `other` can be empty.
if len(self._columns) == 0:
self._columns = {key: col for key, col in other.items()}
if self.num_rows == 0:
# if no rows in current frame; append is equivalent to
# directly updating columns.
self._columns = {key: Column.create(data) for key, data in other.items()}
else:
for key, col in other.items():
if key not in self._columns:
# the column does not exist; init a new column
self.add_column(key, col.scheme, F.context(col.data))
self._columns[key].extend(col.data, col.scheme)
def append(self, other):
......@@ -390,7 +391,6 @@ class Frame(MutableMapping):
"""
if not isinstance(other, Frame):
other = Frame(other)
self._append(other)
self._num_rows += other.num_rows
......@@ -711,7 +711,7 @@ class FrameRef(MutableMapping):
Please note that "deleted" rows are not really deleted, but simply removed
in the reference. As a result, if two FrameRefs point to the same Frame, deleting
from one ref will not relect on the other. By contrast, deleting columns is real.
from one ref will not relect on the other. However, deleting columns is real.
Parameters
----------
......@@ -720,8 +720,6 @@ class FrameRef(MutableMapping):
"""
if isinstance(key, str):
del self._frame[key]
if len(self._frame) == 0:
self.clear()
else:
self.delete_rows(key)
......
......@@ -141,12 +141,17 @@ class DGLGraph(object):
self._readonly=readonly
self._graph = create_graph_index(graph_data, multigraph, readonly)
# frame
self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
if node_frame is None:
self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
self._node_frame = node_frame
if edge_frame is None:
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# msg graph & frame
self._msg_graph = create_graph_index(multigraph=multigraph)
self._msg_frame = FrameRef()
self._msg_edges = []
self.reset_messages()
# registered functions
self._message_func = None
......@@ -154,25 +159,25 @@ class DGLGraph(object):
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
def add_nodes(self, num, data=None):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
data : dict
Optional node feature data.
"""
self._graph.add_nodes(num)
self._msg_graph.add_nodes(num)
#TODO(minjie): change frames
assert reprs is None
# Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num)
if data is None:
# Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num)
else:
self._node_frame.append(data)
def add_edge(self, u, v, reprs=None):
def add_edge(self, u, v, data=None):
"""Add one edge.
Parameters
......@@ -181,21 +186,21 @@ class DGLGraph(object):
The src node.
v : int
The dst node.
reprs : dict
Optional edge representation.
data : dict
Optional node feature data.
See Also
--------
add_edges
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
assert reprs is None
# Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(1)
if data is None:
# Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(1)
else:
self._edge_frame.append(data)
def add_edges(self, u, v, reprs=None):
def add_edges(self, u, v, data=None):
"""Add many edges.
Parameters
......@@ -214,11 +219,12 @@ class DGLGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
assert reprs is None
# Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(len(u))
if data is None:
# Initialize feature placeholders if there are features existing
# NOTE: use max due to edge broadcasting syntax
self._edge_frame.add_rows(max(len(u), len(v)))
else:
self._edge_frame.append(data)
def clear(self):
"""Clear the graph and its storage."""
......@@ -227,13 +233,11 @@ class DGLGraph(object):
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_edges.clear()
def reset_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_edges.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self):
......@@ -672,6 +676,8 @@ class DGLGraph(object):
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
# copy attributes
def _batcher(lst):
......@@ -705,6 +711,8 @@ class DGLGraph(object):
"""
self.clear()
self._graph.from_scipy_sparse_matrix(a)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self):
......
......@@ -23,7 +23,7 @@ def create_test_data(grad=False):
def test_create():
data = create_test_data()
f1 = Frame()
f1 = Frame(num_rows=N)
for k, v in data.items():
f1.update_column(k, v)
print(f1.schemes)
......@@ -56,12 +56,7 @@ def test_column1():
del f['a2']
assert len(f) == 1
del f['a3']
assert f.num_rows == 0
assert len(f) == 0
# add a different length column should succeed
f['a4'] = th.zeros([N+1, D])
assert f.num_rows == N+1
assert len(f) == 1
def test_column2():
# Test frameref column getter/setter
......@@ -119,6 +114,23 @@ def test_append2():
assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N
def test_append3():
# test append on empty frame
f = Frame(num_rows=5)
data = {'h' : th.ones((3, 2))}
f.append(data)
assert f.num_rows == 8
ans = th.cat([th.zeros((5, 2)), th.ones((3, 2))], dim=0)
assert U.allclose(f['h'].data, ans)
# test append with new column
data = {'h' : 2 * th.ones((3, 2)), 'w' : 2 * th.ones((3, 2))}
f.append(data)
assert f.num_rows == 11
ans1 = th.cat([ans, 2 * th.ones((3, 2))], 0)
ans2 = th.cat([th.zeros((8, 2)), 2 * th.ones((3, 2))], 0)
assert U.allclose(f['h'].data, ans1)
assert U.allclose(f['w'].data, ans2)
def test_row1():
# test row getter/setter
data = create_test_data()
......@@ -210,6 +222,15 @@ def test_row3():
for k, v in f.items():
assert U.allclose(v, data[k][newidx])
def test_row4():
# test updating row with empty frame but has preset num_rows
f = FrameRef(Frame(num_rows=5))
rowid = Index(th.tensor([0, 2, 4]))
f[rowid] = {'h' : th.ones((3, 2))}
ans = th.zeros((5, 2))
ans[th.tensor([0, 2, 4])] = th.ones((3, 2))
assert U.allclose(f['h'], ans)
def test_sharing():
data = Frame(create_test_data())
f1 = FrameRef(data, index=[0, 1, 2, 3])
......@@ -290,9 +311,11 @@ if __name__ == '__main__':
test_column2()
test_append1()
test_append2()
test_append3()
test_row1()
test_row2()
test_row3()
test_row4()
test_sharing()
test_slicing()
test_add_rows()
......@@ -4,6 +4,27 @@ import numpy as np
import scipy.sparse as sp
import torch as th
import dgl
import utils as U
def test_graph_creation():
g = dgl.DGLGraph()
# test add nodes with data
g.add_nodes(5)
g.add_nodes(5, {'h' : th.ones((5, 2))})
ans = th.cat([th.zeros(5, 2), th.ones(5, 2)], 0)
U.allclose(ans, g.ndata['h'])
g.ndata['w'] = 2 * th.ones((10, 2))
assert U.allclose(2 * th.ones((10, 2)), g.ndata['w'])
# test add edges with data
g.add_edges([2, 3], [3, 4])
g.add_edges([0, 1], [1, 2], {'m' : th.ones((2, 2))})
ans = th.cat([th.zeros(2, 2), th.ones(2, 2)], 0)
assert U.allclose(ans, g.edata['m'])
# test clear and add again
g.clear()
g.add_nodes(5)
g.ndata['h'] = 3 * th.ones((5, 2))
assert U.allclose(3 * th.ones((5, 2)), g.ndata['h'])
def test_adjmat_speed():
n = 1000
......@@ -36,5 +57,6 @@ def test_incmat_speed():
assert dur2 < dur1
if __name__ == '__main__':
test_graph_creation()
test_adjmat_speed()
test_incmat_speed()
......@@ -7,8 +7,8 @@ Capsule Network Tutorial
**Author**: Jinjing Zhou, `Jake
Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
It is perhaps a little surprising that some of the more classical models can also be described in terms of graphs,
offering a different perspective.
It is perhaps a little surprising that some of the more classical models can
also be described in terms of graphs, offering a different perspective.
This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__.
"""
#######################################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment