Unverified Commit c50b90cf authored by Minjie Wang, committed by GitHub

Change `recvfrom` API; fix corner cases in issue20 (#25)

* Change `recvfrom` API; fix corner cases in issue20

* Change recvfrom to recv
parent d24ccfdd
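
The gist of the change: the old `recvfrom(u, preds)` asked the caller to name the predecessor nodes whose messages should be consumed, while the new `recv(u)` simply consumes whatever messages are pending on u's in-edges; if none are pending, the reduce step is skipped and the update function receives `None`. A minimal before/after sketch, using only the registration and send/recv calls exercised by the tests in this commit:

    from dgl import DGLGraph

    g = DGLGraph()
    for i in range(3):
        g.add_node(i)
        g.set_n_repr(i, i + 1)
    g.add_edge(0, 2)
    g.set_e_repr(0, 2, 1)
    g.add_edge(1, 2)
    g.set_e_repr(1, 2, 1)
    g.register_message_func(lambda hu, hv, e_uv: hu)
    g.register_reduce_func('sum')
    g.register_update_func(lambda h, accum: h if accum is None else h + accum)

    g.sendto(0, 2)
    g.sendto(1, 2)
    # old API: g.recvfrom(2, [0, 1])  -- predecessors named explicitly
    g.recv(2)  # new API: consume all pending messages on node 2's in-edges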
import networkx as nx
from networkx.classes.digraph import DiGraph
# imports needed by the test below (`th` and `nn` are referenced but were not
# imported in this view)
import torch as th
import torch.nn as nn
from mx import mx_Graph
# NOTE: DefaultUpdateModule is the default update module defined later in
# this commit; its exact import path is not shown in this view.

if __name__ == '__main__':
    from torch.autograd import Variable as Var
    th.random.manual_seed(0)

    print("testing vanilla RNN update")
    g_path = mx_Graph(nx.path_graph(2))
    g_path.set_repr(0, th.rand(2, 128))
    g_path.sendto(0, 1)
    g_path.recvfrom(1, [0])
    g_path.readout()

    '''
    # this makes a uni-edge tree
    tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
    m_tr = mx_Graph(tr)
    m_tr.print_all()
    '''

    print("testing GRU update")
    g = mx_Graph(nx.path_graph(3))
    update_net = DefaultUpdateModule(h_dims=4, net_type='gru')
    g.register_update_func(update_net)
    msg_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    g.register_message_func(msg_net)
    for n in g:
        g.set_repr(n, th.rand(2, 4))
    y_pre = g.readout()
    g.update_from(0)
    y_after = g.readout()

    upd_nets = DefaultUpdateModule(h_dims=4, net_type='gru', n_func=2)
    g.register_update_func(upd_nets)
    g.update_from(0)
    g.update_from(0)
import networkx as nx
from mx import mx_Graph
from glimpse import create_glimpse
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as MODELS
import torch.nn.init as INIT
from util import USE_CUDA, cuda
import numpy as np
import skorch
from viz import VisdomWindowManager
import matplotlib.pyplot as plt
batch_size = 32
wm = VisdomWindowManager(port=10248)
def dfs_walk(tree, curr, l):
    # Append (parent, child) on the way down and (child, parent) on the way
    # back up, yielding a full DFS traversal with backtracking edges.
    if len(tree.succ[curr]) == 0:
        return
    else:
        for n in tree.succ[curr]:
            l.append((curr, n))
            dfs_walk(tree, n, l)
            l.append((n, curr))
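# Hypothetical usage note (not part of the original script): on a directed
# path 0 -> 1 -> 2,
#     t = nx.bfs_tree(nx.path_graph(3), 0)
#     walk = []
#     dfs_walk(t, 0, walk)
# fills `walk` with [(0, 1), (1, 2), (2, 1), (1, 0)]: the descent edges
# followed by the backtracking edges.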
def build_cnn(**config):
    cnn_list = []
    filters = config['filters']
    kernel_size = config['kernel_size']
    in_channels = config.get('in_channels', 3)
    final_pool_size = config['final_pool_size']

    for i in range(len(filters)):
        module = nn.Conv2d(
            in_channels if i == 0 else filters[i-1],
            filters[i],
            kernel_size,
            padding=tuple((_ - 1) // 2 for _ in kernel_size),
        )
        INIT.xavier_uniform_(module.weight)
        INIT.constant_(module.bias, 0)
        cnn_list.append(module)
        if i < len(filters) - 1:
            cnn_list.append(nn.LeakyReLU())
    cnn_list.append(nn.AdaptiveMaxPool2d(final_pool_size))

    return nn.Sequential(*cnn_list)
def build_resnet_cnn(**config):
    n_layers = config['n_layers']
    final_pool_size = config['final_pool_size']

    resnet = MODELS.resnet18(pretrained=False)
    cnn_list = list(resnet.children())[0:n_layers]
    cnn_list.append(nn.AdaptiveMaxPool2d(final_pool_size))

    return nn.Sequential(*cnn_list)

def init_canvas(n_nodes):
    fig, ax = plt.subplots(2, 4)
    fig.set_size_inches(16, 8)
    return fig, ax

def display_image(fig, ax, i, im, title):
    im = im.detach().cpu().numpy().transpose(1, 2, 0)
    ax[i // 4, i % 4].imshow(im, cmap='gray', vmin=0, vmax=1)
    ax[i // 4, i % 4].set_title(title)
class MessageModule(nn.Module):
    def forward(self, state):
        h, b_next = [state[k] for k in ['h', 'b_next']]
        return h, b_next
class UpdateModule(nn.Module):
    """
    UpdateModule:

    Returns:
        h: new state
        b: new bounding box
        a: attention (for readout)
        y: prediction
    """
    def __init__(self, **config):
        #h_dims=128,
        #n_classes=10,
        #steps=5,
        #filters=[16, 32, 64, 128, 256],
        #kernel_size=(3, 3),
        #final_pool_size=(2, 2),
        #glimpse_type='gaussian',
        #glimpse_size=(15, 15),
        #cnn='resnet'
        #):
        super(UpdateModule, self).__init__()
        glimpse_type = config['glimpse_type']
        glimpse_size = config['glimpse_size']
        self.glimpse = create_glimpse(glimpse_type, glimpse_size)

        h_dims = config['h_dims']
        n_classes = config['n_classes']
        self.net_b = nn.Sequential(
            nn.Linear(h_dims, h_dims),
            nn.ReLU(),
            nn.Linear(h_dims, self.glimpse.att_params),
        )
        self.net_y = nn.Sequential(
            nn.Linear(h_dims, h_dims),
            nn.ReLU(),
            nn.Linear(h_dims, n_classes),
        )
        self.net_a = nn.Sequential(
            nn.Linear(h_dims, h_dims),
            nn.ReLU(),
            nn.Linear(h_dims, 1),
        )
        self.h_to_h = nn.GRUCell(h_dims * 2, h_dims)
        INIT.orthogonal_(self.h_to_h.weight_hh)

        cnn = config['cnn']
        final_pool_size = config['final_pool_size']
        if cnn == 'resnet':
            n_layers = config['n_layers']
            self.cnn = build_resnet_cnn(
                n_layers=n_layers,
                final_pool_size=final_pool_size,
            )
            self.net_h = nn.Linear(128 * np.prod(final_pool_size), h_dims)
        else:
            filters = config['filters']
            kernel_size = config['kernel_size']
            self.cnn = build_cnn(
                filters=filters,
                kernel_size=kernel_size,
                final_pool_size=final_pool_size,
            )
            self.net_h = nn.Linear(filters[-1] * np.prod(final_pool_size), h_dims)

        self.max_recur = config.get('max_recur', 1)
        self.h_dims = h_dims
    def set_image(self, x):
        self.x = x

    def forward(self, node_state, message):
        h, b, y, b_fix = [node_state[k] for k in ['h', 'b', 'y', 'b_fix']]
        batch_size = h.shape[0]

        if len(message) == 0:
            h_m_avg = h.new(batch_size, self.h_dims).zero_()
        else:
            h_m, b_next = zip(*message)
            h_m_avg = T.stack(h_m).mean(0)
            b = T.stack(b_next).mean(0) if b_fix is None else b_fix
        b_new = b_fix = b
        h_new = h

        for i in range(self.max_recur):
            b_rescaled, _ = self.glimpse.rescale(b_new[:, None], False)
            g = self.glimpse(self.x, b_rescaled)[:, 0]
            h_in = T.cat([self.net_h(self.cnn(g).view(batch_size, -1)), h_m_avg], -1)
            h_new = self.h_to_h(h_in, h_new)
            db = self.net_b(h_new)
            dy = self.net_y(h_new)
            b_new = b + db
            y_new = y + dy
        a_new = self.net_a(h_new)

        return {'h': h_new, 'b': b, 'b_next': b_new, 'a': a_new, 'y': y_new,
                'g': g, 'b_fix': b_fix, 'db': db}
    def update_local(self):    # missing `self` in the original
        pass
class ReadoutModule(nn.Module):
    '''
    Returns the logits of classes
    '''
    def __init__(self, *args, **kwarg):
        super(ReadoutModule, self).__init__()
        self.y = nn.Linear(kwarg['h_dims'], kwarg['n_classes'])

    def forward(self, nodes_state, pretrain=False):
        if pretrain:
            assert len(nodes_state) == 1    # root only
            h = nodes_state[0]['h']
            y = self.y(h)
        else:
            #h = T.stack([s['h'] for s in nodes_state], 1)
            #a = F.softmax(T.stack([s['a'] for s in nodes_state], 1), 1)
            #b_of_h = T.sum(a * h, 1)
            #b_of_h = h[:, -1]
            #y = self.y(b_of_h)
            #y = nodes_state[-1]['y']
            y = T.stack([s['y'] for s in nodes_state], 1)
        return y
class DFSGlimpseSingleObjectClassifier(nn.Module):
    def __init__(self,
                 h_dims=128,
                 n_classes=10,
                 filters=[16, 32, 64, 128, 256],
                 kernel_size=(3, 3),
                 final_pool_size=(2, 2),
                 glimpse_type='gaussian',
                 glimpse_size=(15, 15),
                 cnn='cnn'):
        nn.Module.__init__(self)

        #self.T_MAX_RECUR = kwarg['steps']

        t = nx.balanced_tree(1, 2)
        t_uni = nx.bfs_tree(t, 0)
        self.G = mx_Graph(t)
        self.root = 0
        self.h_dims = h_dims
        self.n_classes = n_classes

        self.message_module = MessageModule()
        self.G.register_message_func(self.message_module)    # default: just copy

        #self.update_module = UpdateModule(h_dims, n_classes, glimpse_size)
        self.update_module = UpdateModule(
            glimpse_type=glimpse_type,
            glimpse_size=glimpse_size,
            n_layers=6,
            h_dims=h_dims,
            n_classes=n_classes,
            final_pool_size=final_pool_size,
            filters=filters,
            kernel_size=kernel_size,
            cnn=cnn,
            max_recur=1,    # T_MAX_RECUR
        )
        self.G.register_update_func(self.update_module)

        self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
        self.G.register_readout_func(self.readout_module)

        self.walk_list = [(0, 1), (1, 2)]
        #dfs_walk(t_uni, self.root, self.walk_list)
    def forward(self, x, pretrain=False):
        batch_size = x.shape[0]
        self.update_module.set_image(x)
        self.G.init_reprs({
            'h': x.new(batch_size, self.h_dims).zero_(),
            'b': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
            'b_next': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
            'a': x.new(batch_size, 1).zero_(),
            'y': x.new(batch_size, self.n_classes).zero_(),
            'g': None,
            'b_fix': None,
            'db': None,
        })

        #TODO: the following recv is needed for the single-object case,
        #TODO: but is not useful (or is wrong) for multi-obj
        self.G.recvfrom(self.root, [])

        if pretrain:
            return self.G.readout([self.root], pretrain=True)
        else:
            for u, v in self.walk_list:
                self.G.update_by_edge((u, v))
                # update local should be inside the update module
                #for i in self.T_MAX_RECUR:
                #    self.G.update_local(u)
            return self.G.readout('all', pretrain=False)
class Net(skorch.NeuralNet):
    def __init__(self, **kwargs):
        # pop() instead of get()+del, so a missing reg_coef falls back to the
        # default instead of raising KeyError on the del
        self.reg_coef_ = kwargs.pop('reg_coef', 1e-4)
        skorch.NeuralNet.__init__(self, **kwargs)

    def initialize_criterion(self):
        # Overriding this method to skip initializing criterion as we don't use it.
        pass

    def get_split_datasets(self, X, y=None, **fit_params):
        # Overriding this method to use our own dataloader to change the X
        # in signature to (train_dataset, valid_dataset)
        X_train, X_valid = X
        train = self.get_dataset(X_train, None)
        valid = self.get_dataset(X_valid, None)
        return train, valid

    def train_step(self, Xi, yi, **fit_params):
        step = skorch.NeuralNet.train_step(self, Xi, yi, **fit_params)
        dbs = [self.module_.G.get_repr(v)['db'] for v in self.module_.G.nodes]
        reg = self.reg_coef_ * sum(db.norm(2, 1).mean() for db in dbs if db is not None)
        loss = step['loss'] + reg
        y_pred = step['y_pred']
        acc = self.get_loss(y_pred, yi, training=False)
        self.history.record_batch('max_param', max(p.abs().max().item() for p in self.module_.parameters()))
        self.history.record_batch('acc', acc.item())
        self.history.record_batch('reg', reg.item())
        return {
            'loss': loss,
            'y_pred': y_pred,
        }
    def get_loss(self, y_pred, y_true, X=None, training=False):
        batch_size, n_steps, _ = y_pred.shape
        if training:
            #return F.cross_entropy(y_pred, y_true)
            y_true = y_true[:, None].expand(batch_size, n_steps)
            return F.cross_entropy(
                y_pred.reshape(batch_size * n_steps, -1),
                y_true.reshape(-1)
            )
        else:
            # pick the step with the most confident prediction and count
            # the correct classifications
            y_prob, y_cls = y_pred.max(-1)
            _, y_prob_maxind = y_prob.max(-1)
            y_cls_final = y_cls.gather(1, y_prob_maxind[:, None])[:, 0]
            return (y_cls_final == y_true).sum()
class Dump(skorch.callbacks.Callback):
    def initialize(self):
        self.epoch = 0
        self.batch = 0
        self.correct = 0
        self.total = 0
        self.best_acc = 0
        self.nviz = 0
        return self

    def on_epoch_begin(self, net, **kwargs):
        self.epoch += 1
        self.batch = 0
        self.correct = 0
        self.total = 0
        self.nviz = 0

    def on_batch_end(self, net, **kwargs):
        self.batch += 1
        if kwargs['training']:
            #print('#', self.epoch, self.batch, kwargs['loss'], kwargs['valid_loss'])
            pass
        else:
            self.correct += kwargs['loss'].item()
            self.total += kwargs['X'].shape[0]
            if self.nviz < 10:
                n_nodes = len(net.module_.G.nodes)
                fig, ax = init_canvas(n_nodes)
                #a = T.stack([net.module_.G.get_repr(v)['a'] for v in net.module_.G.nodes], 1)
                #a = F.softmax(a, 1).detach().cpu().numpy()
                y = T.stack([net.module_.G.get_repr(v)['y'] for v in net.module_.G.nodes], 1)
                y_val, y = y.max(-1)
                for i, n in enumerate(net.module_.G.nodes):
                    repr_ = net.module_.G.get_repr(n)
                    g = repr_['g']
                    if g is None:
                        continue
                    b, _ = net.module_.update_module.glimpse.rescale(repr_['b'], False)
                    display_image(
                        fig,
                        ax,
                        i,
                        g[0],
                        np.array_str(
                            b[0].detach().cpu().numpy(),
                            precision=2, suppress_small=True) +
                        #'a=%.2f' % a[0, i, 0]
                        'y=%d (%.2f)' % (y[0, i], y_val[0, i])
                    )
                wm.display_mpl_figure(fig, win='viz{}'.format(self.nviz))
                self.nviz += 1

    def on_epoch_end(self, net, **kwargs):
        print('@', self.epoch, self.correct, '/', self.total)
        acc = self.correct / self.total
        if self.best_acc < acc:
            self.best_acc = acc
            net.history.record('acc_best', acc)
        else:
            net.history.record('acc_best', None)
def data_generator(dataset, batch_size, shuffle):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                            drop_last=True, num_workers=0)
    for _x, _y, _B in dataloader:
        x = _x[:, None].expand(_x.shape[0], 3, _x.shape[1], _x.shape[2]).float() / 255.
        y = _y.squeeze(1)
        yield cuda(x), cuda(y)
if __name__ == "__main__":
from datasets import MNISTMulti
from torch.utils.data import DataLoader
from sklearn.model_selection import GridSearchCV
mnist_train = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=True)
mnist_valid = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=False, mode='valid')
for reg_coef in [0, 100, 1e-2, 0.1, 1, 1e-3]:
print('Trying reg coef', reg_coef)
net = Net(
module=DFSGlimpseSingleObjectClassifier,
criterion=None,
max_epochs=50,
reg_coef=reg_coef,
optimizer=T.optim.RMSprop,
#optimizer__weight_decay=1e-4,
lr=1e-5,
batch_size=batch_size,
device='cuda' if USE_CUDA else 'cpu',
callbacks=[
Dump(),
skorch.callbacks.Checkpoint(monitor='acc_best'),
skorch.callbacks.ProgressBar(postfix_keys=['train_loss', 'valid_loss', 'acc', 'reg']),
skorch.callbacks.GradientNormClipping(0.01),
#skorch.callbacks.LRScheduler('ReduceLROnPlateau'),
],
iterator_train=data_generator,
iterator_train__shuffle=True,
iterator_valid=data_generator,
iterator_valid__shuffle=False,
)
#net.fit((mnist_train, mnist_valid), pretrain=True, epochs=50)
net.partial_fit((mnist_train, mnist_valid), pretrain=False, epochs=500)
......@@ -80,7 +80,7 @@ class TreeLSTM(nn.Module):
            iterator.append((src, trg))
            frontier = src
-        g.recvfrom(leaves)
+        g.recv(leaves)
        g.propagate(reversed(iterator))
        return self.readout_func(g, train)
......
import torch as th
import torch.nn as nn
import torch.nn.functional as F

'''
Default modules: these are PyTorch specific
    - MessageModule: copy
    - UpdateModule: vanilla RNN
    - ReadoutModule: bag of words
    - ReductionModule: bag of words
'''

class DefaultMessageModule(nn.Module):
    """
    Default message module:
        - copy
    """
    def __init__(self, *args, **kwargs):
        super(DefaultMessageModule, self).__init__(*args, **kwargs)

    def forward(self, x):
        return x

class DefaultUpdateModule(nn.Module):
    """
    Default update module:
        - a fully-connected layer with ReLU (net_type='fwd', the default),
          or a GRU cell (net_type='gru')
    """
    def __init__(self, *args, **kwargs):
        super(DefaultUpdateModule, self).__init__()
        h_dims = self.h_dims = kwargs.get('h_dims', 128)
        net_type = self.net_type = kwargs.get('net_type', 'fwd')
        n_func = self.n_func = kwargs.get('n_func', 1)
        self.f_idx = 0
        self.reduce_func = DefaultReductionModule()
        # nn.ModuleList (rather than a plain list) so the sub-networks are
        # registered as parameters of this module
        if net_type == 'gru':
            self.net = nn.ModuleList([nn.GRUCell(h_dims, h_dims) for i in range(n_func)])
        else:
            self.net = nn.ModuleList([nn.Linear(2 * h_dims, h_dims) for i in range(n_func)])

    def forward(self, x, msgs):
        if not th.is_tensor(x):
            x = th.zeros_like(msgs[0])
        m = self.reduce_func(msgs)
        assert self.f_idx < self.n_func
        if self.net_type == 'gru':
            out = self.net[self.f_idx](m, x)
        else:
            _in = th.cat((m, x), 1)
            out = F.relu(self.net[self.f_idx](_in))
        self.f_idx += 1
        return out

    def reset_f_idx(self):
        self.f_idx = 0

class DefaultReductionModule(nn.Module):
    """
    Default reduction:
        - bag of words (sum over messages)
    """
    def __init__(self, *args, **kwargs):
        super(DefaultReductionModule, self).__init__(*args, **kwargs)

    def forward(self, x_s):
        out = th.stack(x_s)
        out = th.sum(out, dim=0)
        return out

class DefaultReadoutModule(nn.Module):
    """
    Default readout:
        - bag of words (sum over node representations)
    """
    def __init__(self, *args, **kwargs):
        super(DefaultReadoutModule, self).__init__(*args, **kwargs)
        self.reduce_func = DefaultReductionModule()

    def forward(self, x_s):
        return self.reduce_func(x_s)
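
For orientation, a minimal sketch of how these defaults plug into a graph, following the mx_Graph usage in the test script at the top of this commit (hypothetical wiring, not part of the diff):

    import networkx as nx
    import torch as th
    from mx import mx_Graph

    g = mx_Graph(nx.path_graph(3))
    g.register_message_func(DefaultMessageModule())         # messages are copies
    g.register_update_func(DefaultUpdateModule(h_dims=4))   # 'fwd' net by default
    for n in g:
        g.set_repr(n, th.rand(2, 4))
    g.update_from(0)  # send from node 0 and update its successors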
......@@ -307,21 +307,18 @@ class DGLGraph(DiGraph):
        """
        self._internal_trigger_edges(u, v, __EFUNC__)

-    def recvfrom(self, u, preds=None):
-        """Trigger the update function on node u.
+    def recv(self, u):
+        """Receive incoming messages and update the representation of node u.

-        It computes the new node state using the messages and edge
-        states from preds->u. If `u` is one node, `preds` is a list
-        of predecessors. If `u` is a container or tensor of nodes,
-        then `preds[i]` should be the predecessors of `u[i]`.
+        It computes the new node state using the messages sent from the
+        predecessors of node u. If no message is found from the predecessors,
+        the reduce function is skipped and None is passed as the reduced
+        message to the update function.

        Parameters
        ----------
        u : node, container or tensor
            The node to be updated.
-        preds : container
-            Nodes with pre-computed messages to u. Default is all
-            the predecessors.
        """
        u_is_container = isinstance(u, list)
        u_is_tensor = isinstance(u, Tensor)
......@@ -329,18 +326,18 @@ class DGLGraph(DiGraph):
        ufunc = self._glb_func.get(__UFUNC__)
        # TODO(minjie): tensorize the loop.
        for i, uu in enumerate(utils.node_iter(u)):
-            if preds is None:
-                v = list(self.pred[uu])
-            elif u_is_container or u_is_tensor:
-                v = preds[i]
-            else:
-                v = preds
-            # TODO(minjie): tensorize the message batching
+            # reduce phase
            f_reduce = self.nodes[uu].get(__RFUNC__, rfunc)
            assert f_reduce is not None, \
                "Reduce function not registered for node %s" % uu
-            msgs_batch = [self.edges[vv, uu][__MSG__] for vv in v]
+            msgs_batch = [self.edges[vv, uu].pop(__MSG__)
+                          for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
+            if len(msgs_batch) == 0:
+                msgs_reduced = None
+            elif len(msgs_batch) == 1:
+                msgs_reduced = msgs_batch[0]
+            else:
+                msgs_reduced = f_reduce(msgs_batch)
+            # update phase
            f_update = self.nodes[uu].get(__UFUNC__, ufunc)
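
Worth noting in this hunk: `pop(__MSG__)` removes each message from the edge as it is consumed, so messages are one-shot. A second `recv` on the same node finds nothing pending, skips the reduce function, and hands `None` to the update function; the corner-case tests at the bottom of this commit pin down exactly this behavior:

    g.sendto(1, 9)
    g.sendto(2, 9)
    g.recv(9)   # reduces two messages, then updates node 9
    g.recv(9)   # nothing pending: reduce skipped, update sees accum=None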
......@@ -361,17 +358,10 @@ class DGLGraph(DiGraph):
        """
        self.sendto(u, v)
        # TODO(minjie): tensorize the following loops.
-        preds = defaultdict(list)
+        dst = set()
        for uu, vv in utils.edge_iter(u, v):
-            preds[vv].append(uu)
-        if len(preds) == 1:
-            dst = list(preds.keys())[0]
-            src = preds[dst]
-            self.recvfrom(dst, src)
-        elif len(preds) > 1:
-            dst = list(preds.keys())
-            src = [preds[d] for d in dst]
-            self.recvfrom(dst, src)
+            dst.add(vv)
+        self.recv(list(dst))
    def update_to(self, u):
        """Pull messages from the node's predecessors and then update it.
......@@ -386,7 +376,7 @@ class DGLGraph(DiGraph):
        assert uu in self.nodes
        preds = list(self.pred[uu])
        self.sendto(preds, uu)
-        self.recvfrom(uu, preds)
+        self.recv(uu)

    def update_from(self, u):
        """Send message from the node to its successors and update them.
......@@ -409,7 +399,7 @@ class DGLGraph(DiGraph):
        u = [uu for uu, _ in self.edges]
        v = [vv for _, vv in self.edges]
        self.sendto(u, v)
-        self.recvfrom(list(self.nodes()))
+        self.recv(list(self.nodes()))
    def propagate(self, iterator='bfs', **kwargs):
        """Propagate messages and update nodes using iterator.
......
......@@ -34,11 +34,11 @@ def test_sendrecv():
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
-    g.recvfrom(1, [0])
+    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
-    g.recvfrom(9, [5, 6])
+    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

def message_func_hybrid(src, dst, edge):
......@@ -55,11 +55,11 @@ def test_hybridrepr():
    g.register_update_func(update_func_hybrid)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
-    g.recvfrom(1, [0])
+    g.recv(1)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
-    g.recvfrom(9, [5, 6])
+    g.recv(9)
    check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])

if __name__ == '__main__':
......
......@@ -30,11 +30,11 @@ def test_sendrecv():
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    g.sendto(0, 1)
-    g.recvfrom(1, [0])
+    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    g.sendto(5, 9)
    g.sendto(6, 9)
-    g.recvfrom(9, [5, 6])
+    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])

def test_multi_sendrecv():
......@@ -45,15 +45,15 @@ def test_multi_sendrecv():
    g.register_reduce_func('sum')
    # one-many
    g.sendto(0, [1, 2, 3])
-    g.recvfrom([1, 2, 3], [[0], [0], [0]])
+    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one
    g.sendto([6, 7, 8], 9)
-    g.recvfrom(9, [6, 7, 8])
+    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many
    g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
-    g.recvfrom([4, 5, 9], [[0], [0], [4, 5]])
+    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])

def test_update_routines():
......
from dgl import DGLGraph
from dgl.graph import __REPR__

def message_func(hu, hv, e_uv):
    return hu

def message_not_called(hu, hv, e_uv):
    assert False
    return hu

def reduce_not_called(msgs):
    assert False
    return 0

def update_no_msg(h, accum):
    assert accum is None
    return h + 1

def update_func(h, accum):
    assert accum is not None
    return h + accum

def check(g, h):
    nh = [str(g.get_n_repr(i)) for i in range(10)]
    h = [str(x) for x in h]
    assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))

def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i)   # 10 nodes.
        g.set_n_repr(i, i+1)
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.set_e_repr(0, i, 1)
        g.add_edge(i, 9)
        g.set_e_repr(i, 9, 1)
    return g

def test_no_msg_update():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_update_func(update_no_msg)
    for i in range(10):
        g.recv(i)
    check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

def test_double_recv():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func('sum')
    g.register_update_func(update_func)
    g.sendto(1, 9)
    g.sendto(2, 9)
    g.recv(9)
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
    try:
        # The second recv should have a None message
        g.recv(9)
    except:
        return
    assert False

def test_recv_no_pred():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_update_func(update_no_msg)
    g.recv(0)

def test_skipped_reduce():
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_not_called)
    g.register_update_func(update_func)
    g.sendto(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])

if __name__ == '__main__':
    test_no_msg_update()
    test_double_recv()
    test_recv_no_pred()
    test_skipped_reduce()