Unverified Commit b9631912 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Frame] Refactor frame. (#85)

* refactor frame codes

* fix unit test

* fix gcn example

* minor doc/message changes

* raise errors for non-existent columns in FrameRef; sanity check when appending

* fix unittest; change error msg

* Add warning for None initializer

* fix unittest

* use warnings package
parent 66261aee
...@@ -16,10 +16,10 @@ from dgl import DGLGraph ...@@ -16,10 +16,10 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(src, edge): def gcn_msg(src, edge):
return src return {'m' : src['h']}
def gcn_reduce(node, msgs): def gcn_reduce(node, msgs):
return torch.sum(msgs, 1) return {'h' : torch.sum(msgs['m'], 1)}
class NodeApplyModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
...@@ -28,10 +28,10 @@ class NodeApplyModule(nn.Module): ...@@ -28,10 +28,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, node):
h = self.linear(node) h = self.linear(node['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return {'h' : h}
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, def __init__(self,
...@@ -54,14 +54,14 @@ class GCN(nn.Module): ...@@ -54,14 +54,14 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr(features) self.g.set_n_repr({'h' : features})
for layer in self.layers: for layer in self.layers:
# apply dropout # apply dropout
if self.dropout: if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout) g.apply_nodes(apply_node_func=
self.g.set_n_repr(val) lambda node: F.dropout(node['h'], p=self.dropout))
self.g.update_all(gcn_msg, gcn_reduce, layer) self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr() return self.g.pop_n_repr('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module): ...@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module):
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, node):
h = self.linear(node) h = self.linear(node['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return {'h' : h}
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, def __init__(self,
...@@ -49,14 +49,16 @@ class GCN(nn.Module): ...@@ -49,14 +49,16 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr(features) self.g.set_n_repr({'h' : features})
for layer in self.layers: for layer in self.layers:
# apply dropout # apply dropout
if self.dropout: if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout) g.apply_nodes(apply_node_func=
self.g.set_n_repr(val) lambda node: F.dropout(node['h'], p=self.dropout))
self.g.update_all(fn.copy_src(), fn.sum(), layer) self.g.update_all(fn.copy_src(src='h', out='m'),
return self.g.pop_n_repr() fn.sum(msgs='m', out='h'),
layer)
return self.g.pop_n_repr('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -93,23 +93,24 @@ def get_context(arr): ...@@ -93,23 +93,24 @@ def get_context(arr):
return TVMContext( return TVMContext(
TVMContext.STR2MASK[arr.device.type], arr.device.index) TVMContext.STR2MASK[arr.device.type], arr.device.index)
def _typestr(arr_dtype): def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype in (th.float16, th.half): if arr_dtype in (th.float16, th.half):
return 'float16' return TVMType('float16')
elif arr_dtype in (th.float32, th.float): elif arr_dtype in (th.float32, th.float):
return 'float32' return TVMType('float32')
elif arr_dtype in (th.float64, th.double): elif arr_dtype in (th.float64, th.double):
return 'float64' return TVMType('float64')
elif arr_dtype in (th.int16, th.short): elif arr_dtype in (th.int16, th.short):
return 'int16' return TVMType('int16')
elif arr_dtype in (th.int32, th.int): elif arr_dtype in (th.int32, th.int):
return 'int32' return TVMType('int32')
elif arr_dtype in (th.int64, th.long): elif arr_dtype in (th.int64, th.long):
return 'int64' return TVMType('int64')
elif arr_dtype == th.int8: elif arr_dtype == th.int8:
return 'int8' return TVMType('int8')
elif arr_dtype == th.uint8: elif arr_dtype == th.uint8:
return 'uint8' return TVMType('uint8')
else: else:
raise RuntimeError('Unsupported data type:', arr_dtype) raise RuntimeError('Unsupported data type:', arr_dtype)
...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data): ...@@ -130,20 +131,6 @@ def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data.""" """Return a tensor that shares the numpy data."""
return th.from_numpy(np_data) return th.from_numpy(np_data)
'''
data = arr_data
assert data.is_contiguous()
arr = TVMArray()
shape = c_array(tvm_shape_index_t, tuple(data.shape))
arr.data = ctypes.cast(data.data_ptr(), ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = TVMType(_typestr(data.dtype))
arr.ndim = len(shape)
arr.ctx = get_context(data)
return arr
'''
def nonzero_1d(arr): def nonzero_1d(arr):
"""Return a 1D tensor with nonzero element indices in a 1D vector""" """Return a 1D tensor with nonzero element indices in a 1D vector"""
assert arr.dim() == 1 assert arr.dim() == 1
......
"""Module for base types and utilities.""" """Module for base types and utilities."""
from __future__ import absolute_import
import warnings
from ._ffi.base import DGLError
# A special argument for selecting all nodes/edges. # A special argument for selecting all nodes/edges.
ALL = "__ALL__" ALL = "__ALL__"
...@@ -8,3 +13,5 @@ def is_all(arg): ...@@ -8,3 +13,5 @@ def is_all(arg):
__MSG__ = "__MSG__" __MSG__ = "__MSG__"
__REPR__ = "__REPR__" __REPR__ = "__REPR__"
dgl_warning = warnings.warn
This diff is collapsed.
...@@ -504,25 +504,49 @@ class DGLGraph(object): ...@@ -504,25 +504,49 @@ class DGLGraph(object):
self._msg_graph.add_nodes(self._graph.number_of_nodes()) self._msg_graph.add_nodes(self._graph.number_of_nodes())
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node attribute schemes. """Return the node feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of node feature columns.
""" """
return self._node_frame.schemes return self._node_frame.schemes
def edge_attr_schemes(self): def edge_attr_schemes(self):
"""Return the edge attribute schemes. """Return the edge feature schemes.
Returns Returns
------- -------
iterable dict of str to schemes
The set of attribute names The schemes of edge feature columns.
""" """
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_initializer(self, initializer):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._node_frame.set_initializer(initializer)
def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Parameters
----------
initializer : callable
The initializer.
"""
self._edge_frame.set_initializer(initializer)
def set_n_repr(self, hu, u=ALL, inplace=False): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
...@@ -534,12 +558,17 @@ class DGLGraph(object): ...@@ -534,12 +558,17 @@ class DGLGraph(object):
Dictionary type is also supported for `hu`. In this case, each item Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes. will be treated as separate attribute of the nodes.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
hu : tensor or dict of tensor hu : tensor or dict of tensor
Node representation. Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if is_all(u): if is_all(u):
...@@ -607,7 +636,7 @@ class DGLGraph(object): ...@@ -607,7 +636,7 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL): def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or To set multiple edge representations at once, pass `u` and `v` with tensors or
...@@ -618,6 +647,9 @@ class DGLGraph(object): ...@@ -618,6 +647,9 @@ class DGLGraph(object):
Dictionary type is also supported for `h_uv`. In this case, each item Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges. will be treated as separate attribute of the edges.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor h_uv : tensor or dict of tensor
...@@ -626,28 +658,35 @@ class DGLGraph(object): ...@@ -626,28 +658,35 @@ class DGLGraph(object):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL) self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid) self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, h_uv, eid=ALL): def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor h_uv : tensor or dict of tensor
Edge representation. Edge representation.
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
inplace : bool
True if the update is done inplacely
""" """
# sanity check # sanity check
if is_all(eid): if is_all(eid):
...@@ -662,16 +701,18 @@ class DGLGraph(object): ...@@ -662,16 +701,18 @@ class DGLGraph(object):
assert F.shape(h_uv)[0] == num_edges assert F.shape(h_uv)[0] == num_edges
# set # set
if is_all(eid): if is_all(eid):
# update column
if utils.is_dict_like(h_uv): if utils.is_dict_like(h_uv):
for key, val in h_uv.items(): for key, val in h_uv.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else: else:
self._edge_frame[__REPR__] = h_uv self._edge_frame[__REPR__] = h_uv
else: else:
# update row
if utils.is_dict_like(h_uv): if utils.is_dict_like(h_uv):
self._edge_frame[eid] = h_uv self._edge_frame.update_rows(eid, h_uv, inplace=inplace)
else: else:
self._edge_frame[eid] = {__REPR__ : h_uv} self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -793,12 +834,12 @@ class DGLGraph(object): ...@@ -793,12 +834,12 @@ class DGLGraph(object):
""" """
self._apply_edge_func = apply_edge_func self._apply_edge_func = apply_edge_func
def apply_nodes(self, v, apply_node_func="default"): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor v : int, iterable of int, tensor, optional
The node id(s). The node id(s).
apply_node_func : callable apply_node_func : callable
The apply node function. The apply node function.
...@@ -952,8 +993,8 @@ class DGLGraph(object): ...@@ -952,8 +993,8 @@ class DGLGraph(object):
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs} for k in msgs},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
...@@ -961,14 +1002,13 @@ class DGLGraph(object): ...@@ -961,14 +1002,13 @@ class DGLGraph(object):
self._msg_graph.add_edges(new_u, new_v) self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append( self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor()) {k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs} for k in msgs})
)
else: else:
if len(msg_target_rows) > 0: if len(msg_target_rows) > 0:
self._msg_frame.update_rows( self._msg_frame.update_rows(
msg_target_rows, msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())} {__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
) inplace=False)
if len(msg_append_rows) > 0: if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv) new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u) new_u = utils.toindex(new_u)
......
...@@ -20,22 +20,26 @@ def reduce_func(node, msgs): ...@@ -20,22 +20,26 @@ def reduce_func(node, msgs):
reduce_msg_shapes.add(tuple(msgs.shape)) reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3 assert len(msgs.shape) == 3
assert msgs.shape[2] == D assert msgs.shape[2] == D
return {'m' : th.sum(msgs, 1)} return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node): def apply_node_func(node):
return {'h' : node['h'] + node['m']} return {'h' : node['h'] + node['accum']}
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
g.add_nodes(10) # 10 nodes. g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
# 17 edges
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad) ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol}) g.set_n_repr({'h' : ncol})
g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
...@@ -46,8 +50,9 @@ def test_batch_setter_getter(): ...@@ -46,8 +50,9 @@ def test_batch_setter_getter():
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10 assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes # pop nodes
old_len = len(g.get_n_repr())
assert _pfc(g.pop_n_repr('h')) == [0.] * 10 assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0 assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
# set partial nodes # set partial nodes
u = th.tensor([1, 3, 5]) u = th.tensor([1, 3, 5])
...@@ -81,8 +86,9 @@ def test_batch_setter_getter(): ...@@ -81,8 +86,9 @@ def test_batch_setter_getter():
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17 assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# pop edges # pop edges
old_len = len(g.get_e_repr())
assert _pfc(g.pop_e_repr('l')) == [0.] * 17 assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0 assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
# set partial edges (many-many) # set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9]) u = th.tensor([0, 0, 2, 5, 9])
......
...@@ -30,8 +30,10 @@ def generate_graph(grad=False): ...@@ -30,8 +30,10 @@ def generate_graph(grad=False):
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
col = Variable(th.randn(10, D), requires_grad=grad) ncol = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr(col) ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr(ncol)
g.set_e_repr(ecol)
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
......
...@@ -2,14 +2,11 @@ import torch as th ...@@ -2,14 +2,11 @@ import torch as th
from torch.autograd import Variable from torch.autograd import Variable
import numpy as np import numpy as np
from dgl.frame import Frame, FrameRef from dgl.frame import Frame, FrameRef
from dgl.utils import Index from dgl.utils import Index, toindex
N = 10 N = 10
D = 5 D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def check_fail(fn): def check_fail(fn):
try: try:
fn() fn()
...@@ -27,12 +24,13 @@ def test_create(): ...@@ -27,12 +24,13 @@ def test_create():
data = create_test_data() data = create_test_data()
f1 = Frame() f1 = Frame()
for k, v in data.items(): for k, v in data.items():
f1.add_column(k, v) f1.update_column(k, v)
assert f1.schemes == set(data.keys()) print(f1.schemes)
assert f1.keys() == set(data.keys())
assert f1.num_columns == 3 assert f1.num_columns == 3
assert f1.num_rows == N assert f1.num_rows == N
f2 = Frame(data) f2 = Frame(data)
assert f2.schemes == set(data.keys()) assert f2.keys() == set(data.keys())
assert f2.num_columns == 3 assert f2.num_columns == 3
assert f2.num_rows == N assert f2.num_rows == N
f1.clear() f1.clear()
...@@ -45,9 +43,9 @@ def test_column1(): ...@@ -45,9 +43,9 @@ def test_column1():
f = Frame(data) f = Frame(data)
assert f.num_rows == N assert f.num_rows == N
assert len(f) == 3 assert len(f) == 3
assert check_eq(f['a1'], data['a1']) assert th.allclose(f['a1'].data, data['a1'].data)
f['a1'] = data['a2'] f['a1'] = data['a2']
assert check_eq(f['a2'], data['a2']) assert th.allclose(f['a2'].data, data['a2'].data)
# add a different length column should fail # add a different length column should fail
def failed_add_col(): def failed_add_col():
f['a4'] = th.zeros([N+1, D]) f['a4'] = th.zeros([N+1, D])
...@@ -70,16 +68,15 @@ def test_column2(): ...@@ -70,16 +68,15 @@ def test_column2():
f = FrameRef(data, [3, 4, 5, 6, 7]) f = FrameRef(data, [3, 4, 5, 6, 7])
assert f.num_rows == 5 assert f.num_rows == 5
assert len(f) == 3 assert len(f) == 3
assert check_eq(f['a1'], data['a1'][3:8]) assert th.allclose(f['a1'], data['a1'].data[3:8])
# set column should reflect on the referenced data # set column should reflect on the referenced data
f['a1'] = th.zeros([5, D]) f['a1'] = th.zeros([5, D])
assert check_eq(data['a1'][3:8], th.zeros([5, D])) assert th.allclose(data['a1'].data[3:8], th.zeros([5, D]))
# add new column should be padded with zero # add new partial column should fail with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
def failed_add_col():
f['a4'] = th.ones([5, D]) f['a4'] = th.ones([5, D])
assert len(data) == 4 assert check_fail(failed_add_col)
assert check_eq(data['a4'][0:3], th.zeros([3, D]))
assert check_eq(data['a4'][3:8], th.ones([5, D]))
assert check_eq(data['a4'][8:10], th.zeros([2, D]))
def test_append1(): def test_append1():
# test append API on Frame # test append API on Frame
...@@ -91,9 +88,14 @@ def test_append1(): ...@@ -91,9 +88,14 @@ def test_append1():
f1.append(f2) f1.append(f2)
assert f1.num_rows == 2 * N assert f1.num_rows == 2 * N
c1 = f1['a1'] c1 = f1['a1']
assert c1.shape == (2 * N, D) assert c1.data.shape == (2 * N, D)
truth = th.cat([data['a1'], data['a1']]) truth = th.cat([data['a1'], data['a1']])
assert check_eq(truth, c1) assert th.allclose(truth, c1.data)
# append dict of different length columns should fail
f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))}
def failed_append():
f1.append(f3)
assert check_fail(failed_append)
def test_append2(): def test_append2():
# test append on FrameRef # test append on FrameRef
...@@ -113,7 +115,7 @@ def test_append2(): ...@@ -113,7 +115,7 @@ def test_append2():
assert not f.is_span_whole_column() assert not f.is_span_whole_column()
assert f.num_rows == 3 * N assert f.num_rows == 3 * N
new_idx = list(range(N)) + list(range(2*N, 4*N)) new_idx = list(range(N)) + list(range(2*N, 4*N))
assert check_eq(f.index().tousertensor(), th.tensor(new_idx)) assert th.all(f.index().tousertensor() == th.tensor(new_idx, dtype=th.int64))
assert data.num_rows == 4 * N assert data.num_rows == 4 * N
def test_row1(): def test_row1():
...@@ -127,13 +129,13 @@ def test_row1(): ...@@ -127,13 +129,13 @@ def test_row1():
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert th.allclose(v, data[k][rowid])
# test duplicate keys # test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1])) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
for k, v in rows.items(): for k, v in rows.items():
assert v.shape == (len(rowid), D) assert v.shape == (len(rowid), D)
assert check_eq(v, data[k][rowid]) assert th.allclose(v, data[k][rowid])
# setter # setter
rowid = Index(th.tensor([0, 2, 4])) rowid = Index(th.tensor([0, 2, 4]))
...@@ -143,12 +145,14 @@ def test_row1(): ...@@ -143,12 +145,14 @@ def test_row1():
} }
f[rowid] = vals f[rowid] = vals
for k, v in f[rowid].items(): for k, v in f[rowid].items():
assert check_eq(v, th.zeros((len(rowid), D))) assert th.allclose(v, th.zeros((len(rowid), D)))
# setting rows with new column should automatically add a new column # setting rows with new column should raise error with error initializer
f.set_initializer(lambda shape, dtype : assert_(False))
def failed_update_rows():
vals['a4'] = th.ones((len(rowid), D)) vals['a4'] = th.ones((len(rowid), D))
f[rowid] = vals f[rowid] = vals
assert len(f) == 4 assert check_fail(failed_update_rows)
def test_row2(): def test_row2():
# test row getter/setter autograd compatibility # test row getter/setter autograd compatibility
...@@ -161,13 +165,13 @@ def test_row2(): ...@@ -161,13 +165,13 @@ def test_row2():
rowid = Index(th.tensor([0, 2])) rowid = Index(th.tensor([0, 2]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])) assert th.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
c1.grad.data.zero_() c1.grad.data.zero_()
# test duplicate keys # test duplicate keys
rowid = Index(th.tensor([8, 2, 2, 1])) rowid = Index(th.tensor([8, 2, 2, 1]))
rows = f[rowid] rows = f[rowid]
rows['a1'].backward(th.ones((len(rowid), D))) rows['a1'].backward(th.ones((len(rowid), D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])) assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
c1.grad.data.zero_() c1.grad.data.zero_()
# setter # setter
...@@ -180,8 +184,8 @@ def test_row2(): ...@@ -180,8 +184,8 @@ def test_row2():
f[rowid] = vals f[rowid] = vals
c11 = f['a1'] c11 = f['a1']
c11.backward(th.ones((N, D))) c11.backward(th.ones((N, D)))
assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.])) assert th.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
assert check_eq(vals['a1'].grad, th.ones((len(rowid), D))) assert th.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
assert vals['a2'].grad is None assert vals['a2'].grad is None
def test_row3(): def test_row3():
...@@ -201,8 +205,9 @@ def test_row3(): ...@@ -201,8 +205,9 @@ def test_row3():
newidx = list(range(N)) newidx = list(range(N))
newidx.pop(2) newidx.pop(2)
newidx.pop(2) newidx.pop(2)
newidx = toindex(newidx)
for k, v in f.items(): for k, v in f.items():
assert check_eq(v, data[k][th.tensor(newidx)]) assert th.allclose(v, data[k][newidx])
def test_sharing(): def test_sharing():
data = Frame(create_test_data()) data = Frame(create_test_data())
...@@ -210,10 +215,10 @@ def test_sharing(): ...@@ -210,10 +215,10 @@ def test_sharing():
f2 = FrameRef(data, index=[2, 3, 4, 5, 6]) f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
# test read # test read
for k, v in f1.items(): for k, v in f1.items():
assert check_eq(data[k][0:4], v) assert th.allclose(data[k].data[0:4], v)
for k, v in f2.items(): for k, v in f2.items():
assert check_eq(data[k][2:7], v) assert th.allclose(data[k].data[2:7], v)
f2_a1 = f2['a1'] f2_a1 = f2['a1'].data
# test write # test write
# update own ref should not been seen by the other. # update own ref should not been seen by the other.
f1[Index(th.tensor([0, 1]))] = { f1[Index(th.tensor([0, 1]))] = {
...@@ -221,7 +226,7 @@ def test_sharing(): ...@@ -221,7 +226,7 @@ def test_sharing():
'a2' : th.zeros([2, D]), 'a2' : th.zeros([2, D]),
'a3' : th.zeros([2, D]), 'a3' : th.zeros([2, D]),
} }
assert check_eq(f2['a1'], f2_a1) assert th.allclose(f2['a1'], f2_a1)
# update shared space should been seen by the other. # update shared space should been seen by the other.
f1[Index(th.tensor([2, 3]))] = { f1[Index(th.tensor([2, 3]))] = {
'a1' : th.ones([2, D]), 'a1' : th.ones([2, D]),
...@@ -229,7 +234,7 @@ def test_sharing(): ...@@ -229,7 +234,7 @@ def test_sharing():
'a3' : th.ones([2, D]), 'a3' : th.ones([2, D]),
} }
f2_a1[0:2] = th.ones([2, D]) f2_a1[0:2] = th.ones([2, D])
assert check_eq(f2['a1'], f2_a1) assert th.allclose(f2['a1'], f2_a1)
if __name__ == '__main__': if __name__ == '__main__':
test_create() test_create()
......
...@@ -123,6 +123,7 @@ def test_update_all_multi_fn(): ...@@ -123,6 +123,7 @@ def test_update_all_multi_fn():
return {'v2': th.sum(msgs['m2'], 1)} return {'v2': th.sum(msgs['m2'], 1)}
g = generate_graph() g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10,)), 'v2' : th.zeros((10,))})
fld = 'f2' fld = 'f2'
# update all, mix of builtin and UDF # update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func], g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
...@@ -173,6 +174,8 @@ def test_send_and_recv_multi_fn(): ...@@ -173,6 +174,8 @@ def test_send_and_recv_multi_fn():
return {'v2' : th.sum(msgs['m2'], 1)} return {'v2' : th.sum(msgs['m2'], 1)}
g = generate_graph() g = generate_graph()
g.set_n_repr({'v1' : th.zeros((10, D)), 'v2' : th.zeros((10, D)),
'v3' : th.zeros((10, D))})
fld = 'f2' fld = 'f2'
# send and recv, mix of builtin and UDF # send and recv, mix of builtin and UDF
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment