Unverified commit b9631912, authored by Minjie Wang, committed by GitHub
Browse files

[Frame] Refactor frame. (#85)

* refactor frame codes

* fix unit test

* fix gcn example

* minor doc/message changes

* raise errors for non-existent columns in FrameRef; sanity check when appending

* fix unittest; change error msg

* Add warning for none initializer

* fix unittest

* use warnings package
parent 66261aee
......@@ -16,10 +16,10 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
return src
return {'m' : src['h']}
def gcn_reduce(node, msgs):
return torch.sum(msgs, 1)
return {'h' : torch.sum(msgs['m'], 1)}
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
......@@ -28,10 +28,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation
def forward(self, node):
h = self.linear(node)
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return h
return {'h' : h}
class GCN(nn.Module):
def __init__(self,
......@@ -54,14 +54,14 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
self.g.set_n_repr({'h' : features})
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr()
return self.g.pop_n_repr('h')
def main(args):
# load and preprocess dataset
......
......@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation
def forward(self, node):
h = self.linear(node)
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return h
return {'h' : h}
class GCN(nn.Module):
def __init__(self,
......@@ -49,14 +49,16 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
self.g.set_n_repr({'h' : features})
for layer in self.layers:
# apply dropout
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(fn.copy_src(), fn.sum(), layer)
return self.g.pop_n_repr()
g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout))
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msgs='m', out='h'),
layer)
return self.g.pop_n_repr('h')
def main(args):
# load and preprocess dataset
......
......@@ -93,23 +93,24 @@ def get_context(arr):
return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype):
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half):
return 'float16'
return TVMType('float16')
elif arr_dtype in (th.float32, th.float):
return 'float32'
return TVMType('float32')
elif arr_dtype in (th.float64, th.double):
return 'float64'
return TVMType('float64')
elif arr_dtype in (th.int16, th.short):
return 'int16'
return TVMType('int16')
elif arr_dtype in (th.int32, th.int):
return 'int32'
return TVMType('int32')
elif arr_dtype in (th.int64, th.long):
return 'int64'
return TVMType('int64')
elif arr_dtype == th.int8:
return 'int8'
return TVMType('int8')
elif arr_dtype == th.uint8:
return 'uint8'
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
......@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data."""
return th.from_numpy(np_data)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr):
"""Return a 1D tensor with nonzero element indices in a 1D vector"""
assert arr.dim() == 1
......
"""Module for base types and utilities."""
from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
# A special argument for selecting all nodes/edges.
ALL = "__ALL__"
......@@ -8,3 +13,5 @@ def is_all(arg):
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
dgl_warning = warnings.warn
This diff is collapsed.
......@@ -504,25 +504,49 @@ class DGLGraph(object):
self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self):
"""Return the node attribute schemes.
"""Return the node feature schemes.
Returns
-------
iterable
The set of attribute names
dict of str to schemes
The schemes of node feature columns.
"""
return self._node_frame.schemes
def edge_attr_schemes(self):
"""Return the edge attribute schemes.
"""Return the edge feature schemes.
Returns
-------
iterable
The set of attribute names
dict of str to schemes
The schemes of edge feature columns.
"""
return self._edge_frame.schemes
def set_n_initializer(self, initializer):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._node_frame.set_initializer(initializer)
def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._edge_frame.set_initializer(initializer)
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation.
......@@ -534,12 +558,17 @@ class DGLGraph(object):
Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
hu : tensor or dict of tensor
Node representation.
u : node, container or tensor
The node(s).
inplace : bool
True if the update is done inplacely
"""
# sanity check
if is_all(u):
......@@ -607,7 +636,7 @@ class DGLGraph(object):
"""
return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL):
def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or
......@@ -618,6 +647,9 @@ class DGLGraph(object):
Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
h_uv : tensor or dict of tensor
......@@ -626,28 +658,35 @@ class DGLGraph(object):
The source node(s).
v : node, container or tensor
The destination node(s).
inplace : bool
True if the update is done inplacely
"""
# sanity check
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL)
self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace)
else:
u = utils.toindex(u)
v = utils.toindex(v)
_, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid)
self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, h_uv, eid=ALL):
def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
h_uv : tensor or dict of tensor
Edge representation.
eid : int, container or tensor
The edge id(s).
inplace : bool
True if the update is done inplacely
"""
# sanity check
if is_all(eid):
......@@ -662,16 +701,18 @@ class DGLGraph(object):
assert F.shape(h_uv)[0] == num_edges
# set
if is_all(eid):
# update column
if utils.is_dict_like(h_uv):
for key, val in h_uv.items():
self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else:
# update row
if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv
self._edge_frame.update_rows(eid, h_uv, inplace=inplace)
else:
self._edge_frame[eid] = {__REPR__ : h_uv}
self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation.
......@@ -793,12 +834,12 @@ class DGLGraph(object):
"""
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default"):
def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations.
Parameters
----------
v : int, iterable of int, tensor
v : int, iterable of int, tensor, optional
The node id(s).
apply_node_func : callable
The apply node function.
......@@ -952,8 +993,8 @@ class DGLGraph(object):
self._msg_frame.update_rows(
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs}
)
for k in msgs},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
......@@ -961,14 +1002,13 @@ class DGLGraph(object):
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs}
)
for k in msgs})
else:
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())}
)
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
......
......@@ -20,22 +20,26 @@ def reduce_func(node, msgs):
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'m' : th.sum(msgs, 1)}
return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['m']}
return {'h' : node['h'] + node['accum']}
def generate_graph(grad=False):
g = DGLGraph()
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
# 17 edges
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol})
g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
return g
def test_batch_setter_getter():
......@@ -46,8 +50,9 @@ def test_batch_setter_getter():
g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes
old_len = len(g.get_n_repr())
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0
assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))})
# set partial nodes
u = th.tensor([1, 3, 5])
......@@ -81,8 +86,9 @@ def test_batch_setter_getter():
g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# pop edges
old_len = len(g.get_e_repr())
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0
assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))})
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
......
......@@ -30,8 +30,10 @@ def generate_graph(grad=False):
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
col = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr(col)
ncol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr(ncol)
g.set_e_repr(ecol)
return g
def test_batch_setter_getter():
......
......@@ -2,14 +2,11 @@ import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame, FrameRef
from dgl.utils import Index
from dgl.utils import Index, toindex
N = 10
D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def check_fail(fn):
try:
fn()
......@@ -27,12 +24,13 @@ def test_create():
data = create_test_data()
f1 = Frame()
for k, v in data.items():
f1.add_column(k, v)
assert f1.schemes == set(data.keys())
f1.update_column(k, v)
print(f1.schemes)
assert f1.keys() == set(data.keys())
assert f1.num_columns == 3
assert f1.num_rows == N
f2 = Frame(data)
assert f2.schemes == set(data.keys())
assert f2.keys() == set(data.keys())
assert f2.num_columns == 3
assert f2.num_rows == N
f1.clear()
......@@ -45,9 +43,9 @@ def test_column1():
f = Frame(data)
assert f.num_rows == N
assert len(f) == 3
assert check_eq(f['a1'], data['a1'])
assert th.allclose(f['a1'].data, data['a1'].data)
f['a1'] = data['a2']
assert check_eq(f['a2'], data['a2'])
assert th.allclose(f['a2'].data, data['a2'].data)
# add a different length column should fail
def failed_add_col():
f['a4'] = th.zeros([N+1, D])
......@@ -70,16 +68,15 @@ def test_column2():
f = FrameRef(data, [3, 4, 5, 6, 7])
assert f.num_rows == 5
assert len(f) == 3
assert check_eq(f['a1'], data['a1'][3:8])
assert th.allclose(f['a1'], data['a1'].data[3:8])
# set column should reflect on the referenced data
f['a1'] = th.zeros([5, D])
assert check_eq(data['a1'][3:8], th.zeros([5, D]))
# add new column should be padded with zero
assert th.allclose(data['a1'].data[3:8], th.zeros([5, D]))
# add new partial column should fail with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
def failed_add_col():
f['a4'] = th.ones([5, D])
assert len(data) == 4
assert check_eq(data['a4'][0:3], th.zeros([3, D]))
assert check_eq(data['a4'][3:8], th.ones([5, D]))
assert check_eq(data['a4'][8:10], th.zeros([2, D]))
assert check_fail(failed_add_col)
def test_append1():
# test append API on Frame
......@@ -91,9 +88,14 @@ def test_append1():
f1.append(f2)
assert f1.num_rows == 2 * N
c1 = f1['a1']
assert c1.shape == (2 * N, D)
assert c1.data.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']])
assert check_eq(truth, c1)
assert th.allclose(truth, c1.data)
# append dict of different length columns should fail
f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))}
def failed_append():
f1.append(f3)
assert check_fail(failed_append)
def test_append2():
# test append on FrameRef
......@@ -113,7 +115,7 @@ def test_append2():
assert not f.is_span_whole_column()
assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index().tousertensor(), th.tensor(new_idx))
assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N
def test_row1():
......@@ -127,13 +129,13 @@ def test_row1():
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid])
assert th.allclose(v, data[k][rowid])
# test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
for k, v in rows.items():
assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid])
assert th.allclose(v, data[k][rowid])
# setter
rowid = Index(th.tensor([0, 2, 4]))
......@@ -143,12 +145,14 @@ def test_row1():
}
f[rowid] = vals
for k, v in f[rowid].items():
assert check_eq(v, th.zeros((len(rowid), D)))
assert th.allclose(v, th.zeros((len(rowid), D)))
# setting rows with new column should automatically add a new column
# setting rows with new column should raise error with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
def failed_update_rows():
vals['a4'] = th.ones((len(rowid), D))
f[rowid] = vals
assert len(f) == 4
assert check_fail(failed_update_rows)
def test_row2():
# test row getter/setter autograd compatibility
......@@ -161,13 +165,13 @@ def test_row2():
rowid = Index(th.tensor([0, 2]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
assert th.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_()
# test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
c1.grad.data.zero_()
# setter
......@@ -180,8 +184,8 @@ def test_row2():
f[rowid] = vals
c11 = f['a1']
c11.backward(th.ones((N, D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert th.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
assert vals['a2'].grad is None
def test_row3():
......@@ -201,8 +205,9 @@ def test_row3():
newidx = list(range(N))
newidx.pop(2)
newidx.pop(2)
newidx = toindex(newidx)
for k, v in f.items():
assert check_eq(v, data[k][th.tensor(newidx)])
assert th.allclose(v, data[k][newidx])
def test_sharing():
data = Frame(create_test_data())
......@@ -210,10 +215,10 @@ def test_sharing():
f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
# test read
for k, v in f1.items():
assert check_eq(data[k][0:4], v)
assert th.allclose(data[k].data[0:4], v)
for k, v in f2.items():
assert check_eq(data[k][2:7], v)
f2_a1 = f2['a1']
assert th.allclose(data[k].data[2:7], v)
f2_a1 = f2['a1'].data
# test write
# update own ref should not been seen by the other.
f1[Index(th.tensor([0, 1]))] = {
......@@ -221,7 +226,7 @@ def test_sharing():
'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]),
}
assert check_eq(f2['a1'], f2_a1)
assert th.allclose(f2['a1'], f2_a1)
# update shared space should been seen by the other.
f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]),
......@@ -229,7 +234,7 @@ def test_sharing():
'a3' : th.ones([2, D]),
}
f2_a1[0:2] = th.ones([2, D])
assert check_eq(f2['a1'], f2_a1)
assert th.allclose(f2['a1'], f2_a1)
if __name__ == '__main__':
test_create()
......
......@@ -123,6 +123,7 @@ def test_update_all_multi_fn():
return {'v2': th.sum(msgs['m2'], 1)}
g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10,)), 'v2' : th.zeros((10,))})
fld = 'f2'
# update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
......@@ -173,6 +174,8 @@ def test_send_and_recv_multi_fn():
return {'v2' : th.sum(msgs['m2'], 1)}
g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
'v3' : th.zeros((10, D))})
fld = 'f2'
# send and recv, mix of builtin and UDF
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment