Unverified Commit c50b90cf authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Change `recvfrom` API; fix corner cases in issue20 (#25)

* Change ; fix corner cases in issue20

* Change recvfrom to recv
parent d24ccfdd
import networkx as nx
from networkx.classes.digraph import DiGraph
if __name__ == '__main__':
from torch.autograd import Variable as Var
th.random.manual_seed(0)
print("testing vanilla RNN update")
g_path = mx_Graph(nx.path_graph(2))
g_path.set_repr(0, th.rand(2, 128))
g_path.sendto(0, 1)
g_path.recvfrom(1, [0])
g_path.readout()
'''
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr.print_all()
'''
print("testing GRU update")
g = mx_Graph(nx.path_graph(3))
update_net = DefaultUpdateModule(h_dims=4, net_type='gru')
g.register_update_func(update_net)
msg_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
g.register_message_func(msg_net)
for n in g:
g.set_repr(n, th.rand(2, 4))
y_pre = g.readout()
g.update_from(0)
y_after = g.readout()
upd_nets = DefaultUpdateModule(h_dims=4, net_type='gru', n_func=2)
g.register_update_func(upd_nets)
g.update_from(0)
g.update_from(0)
import networkx as nx
from mx import mx_Graph
from glimpse import create_glimpse
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as MODELS
import torch.nn.init as INIT
from util import USE_CUDA, cuda
import numpy as np
import skorch
from viz import VisdomWindowManager
import matplotlib.pyplot as plt
batch_size = 32
wm = VisdomWindowManager(port=10248)
def dfs_walk(tree, curr, l):
if len(tree.succ[curr]) == 0:
return
else:
for n in tree.succ[curr]:
l.append((curr, n))
dfs_walk(tree, n, l)
l.append((n, curr))
def build_cnn(**config):
cnn_list = []
filters = config['filters']
kernel_size = config['kernel_size']
in_channels = config.get('in_channels', 3)
final_pool_size = config['final_pool_size']
for i in range(len(filters)):
module = nn.Conv2d(
in_channels if i == 0 else filters[i-1],
filters[i],
kernel_size,
padding=tuple((_ - 1) // 2 for _ in kernel_size),
)
INIT.xavier_uniform_(module.weight)
INIT.constant_(module.bias, 0)
cnn_list.append(module)
if i < len(filters) - 1:
cnn_list.append(nn.LeakyReLU())
cnn_list.append(nn.AdaptiveMaxPool2d(final_pool_size))
return nn.Sequential(*cnn_list)
def build_resnet_cnn(**config):
n_layers = config['n_layers']
final_pool_size = config['final_pool_size']
resnet = MODELS.resnet18(pretrained=False)
cnn_list = list(resnet.children())[0:n_layers]
cnn_list.append(nn.AdaptiveMaxPool2d(final_pool_size))
return nn.Sequential(*cnn_list)
def init_canvas(n_nodes):
fig, ax = plt.subplots(2, 4)
fig.set_size_inches(16, 8)
return fig, ax
def display_image(fig, ax, i, im, title):
im = im.detach().cpu().numpy().transpose(1, 2, 0)
ax[i // 4, i % 4].imshow(im, cmap='gray', vmin=0, vmax=1)
ax[i // 4, i % 4].set_title(title)
class MessageModule(nn.Module):
def forward(self, state):
h, b_next = [state[k] for k in ['h', 'b_next']]
return h, b_next
class UpdateModule(nn.Module):
"""
UpdateModule:
Returns:
h: new state
b: new bounding box
a: attention (for readout)
y: prediction
"""
def __init__(self, **config):
#h_dims=128,
#n_classes=10,
#steps=5,
#filters=[16, 32, 64, 128, 256],
#kernel_size=(3, 3),
#final_pool_size=(2, 2),
#glimpse_type='gaussian',
#glimpse_size=(15, 15),
#cnn='resnet'
#):
super(UpdateModule, self).__init__()
glimpse_type = config['glimpse_type']
glimpse_size = config['glimpse_size']
self.glimpse = create_glimpse(glimpse_type, glimpse_size)
h_dims = config['h_dims']
n_classes = config['n_classes']
self.net_b = nn.Sequential(
nn.Linear(h_dims, h_dims),
nn.ReLU(),
nn.Linear(h_dims, self.glimpse.att_params),
)
self.net_y = nn.Sequential(
nn.Linear(h_dims, h_dims),
nn.ReLU(),
nn.Linear(h_dims, n_classes),
)
self.net_a = nn.Sequential(
nn.Linear(h_dims, h_dims),
nn.ReLU(),
nn.Linear(h_dims, 1),
)
self.h_to_h = nn.GRUCell(h_dims * 2, h_dims)
INIT.orthogonal_(self.h_to_h.weight_hh)
cnn = config['cnn']
final_pool_size = config['final_pool_size']
if cnn == 'resnet':
n_layers = config['n_layers']
self.cnn = build_resnet_cnn(
n_layers=n_layers,
final_pool_size=final_pool_size,
)
self.net_h = nn.Linear(128 * np.prod(final_pool_size), h_dims)
else:
filters = config['filters']
kernel_size = config['kernel_size']
self.cnn = build_cnn(
filters=filters,
kernel_size=kernel_size,
final_pool_size=final_pool_size,
)
self.net_h = nn.Linear(filters[-1] * np.prod(final_pool_size), h_dims)
self.max_recur = config.get('max_recur', 1)
self.h_dims = h_dims
def set_image(self, x):
self.x = x
def forward(self, node_state, message):
h, b, y, b_fix = [node_state[k] for k in ['h', 'b', 'y', 'b_fix']]
batch_size = h.shape[0]
if len(message) == 0:
h_m_avg = h.new(batch_size, self.h_dims).zero_()
else:
h_m, b_next = zip(*message)
h_m_avg = T.stack(h_m).mean(0)
b = T.stack(b_next).mean(0) if b_fix is None else b_fix
b_new = b_fix = b
h_new = h
for i in range(self.max_recur):
b_rescaled, _ = self.glimpse.rescale(b_new[:, None], False)
g = self.glimpse(self.x, b_rescaled)[:, 0]
h_in = T.cat([self.net_h(self.cnn(g).view(batch_size, -1)), h_m_avg], -1)
h_new = self.h_to_h(h_in, h_new)
db = self.net_b(h_new)
dy = self.net_y(h_new)
b_new = b + db
y_new = y + dy
a_new = self.net_a(h_new)
return {'h': h_new, 'b': b, 'b_next': b_new, 'a': a_new, 'y': y_new, 'g': g, 'b_fix': b_fix, 'db': db}
def update_local():
pass
class ReadoutModule(nn.Module):
'''
Returns the logits of classes
'''
def __init__(self, *args, **kwarg):
super(ReadoutModule, self).__init__()
self.y = nn.Linear(kwarg['h_dims'], kwarg['n_classes'])
def forward(self, nodes_state, pretrain=False):
if pretrain:
assert len(nodes_state) == 1 # root only
h = nodes_state[0]['h']
y = self.y(h)
else:
#h = T.stack([s['h'] for s in nodes_state], 1)
#a = F.softmax(T.stack([s['a'] for s in nodes_state], 1), 1)
#b_of_h = T.sum(a * h, 1)
#b_of_h = h[:, -1]
#y = self.y(b_of_h)
#y = nodes_state[-1]['y']
y = T.stack([s['y'] for s in nodes_state], 1)
return y
class DFSGlimpseSingleObjectClassifier(nn.Module):
def __init__(self,
h_dims=128,
n_classes=10,
filters=[16, 32, 64, 128, 256],
kernel_size=(3, 3),
final_pool_size=(2, 2),
glimpse_type='gaussian',
glimpse_size=(15, 15),
cnn='cnn'
):
nn.Module.__init__(self)
#self.T_MAX_RECUR = kwarg['steps']
t = nx.balanced_tree(1, 2)
t_uni = nx.bfs_tree(t, 0)
self.G = mx_Graph(t)
self.root = 0
self.h_dims = h_dims
self.n_classes = n_classes
self.message_module = MessageModule()
self.G.register_message_func(self.message_module) # default: just copy
#self.update_module = UpdateModule(h_dims, n_classes, glimpse_size)
self.update_module = UpdateModule(
glimpse_type=glimpse_type,
glimpse_size=glimpse_size,
n_layers=6,
h_dims=h_dims,
n_classes=n_classes,
final_pool_size=final_pool_size,
filters=filters,
kernel_size=kernel_size,
cnn=cnn,
max_recur=1, # T_MAX_RECUR
)
self.G.register_update_func(self.update_module)
self.readout_module = ReadoutModule(h_dims=h_dims, n_classes=n_classes)
self.G.register_readout_func(self.readout_module)
self.walk_list = [(0, 1), (1, 2)]
#dfs_walk(t_uni, self.root, self.walk_list)
def forward(self, x, pretrain=False):
batch_size = x.shape[0]
self.update_module.set_image(x)
self.G.init_reprs({
'h': x.new(batch_size, self.h_dims).zero_(),
'b': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
'b_next': x.new(batch_size, self.update_module.glimpse.att_params).zero_(),
'a': x.new(batch_size, 1).zero_(),
'y': x.new(batch_size, self.n_classes).zero_(),
'g': None,
'b_fix': None,
'db': None,
})
#TODO: the following two lines is needed for single object
#TODO: but not useful or wrong for multi-obj
self.G.recvfrom(self.root, [])
if pretrain:
return self.G.readout([self.root], pretrain=True)
else:
for u, v in self.walk_list:
self.G.update_by_edge((u, v))
# update local should be inside the update module
#for i in self.T_MAX_RECUR:
# self.G.update_local(u)
return self.G.readout('all', pretrain=False)
class Net(skorch.NeuralNet):
def __init__(self, **kwargs):
self.reg_coef_ = kwargs.get('reg_coef', 1e-4)
del kwargs['reg_coef']
skorch.NeuralNet.__init__(self, **kwargs)
def initialize_criterion(self):
# Overriding this method to skip initializing criterion as we don't use it.
pass
def get_split_datasets(self, X, y=None, **fit_params):
# Overriding this method to use our own dataloader to change the X
# in signature to (train_dataset, valid_dataset)
X_train, X_valid = X
train = self.get_dataset(X_train, None)
valid = self.get_dataset(X_valid, None)
return train, valid
def train_step(self, Xi, yi, **fit_params):
step = skorch.NeuralNet.train_step(self, Xi, yi, **fit_params)
dbs = [self.module_.G.get_repr(v)['db'] for v in self.module_.G.nodes]
reg = self.reg_coef_ * sum(db.norm(2, 1).mean() for db in dbs if db is not None)
loss = step['loss'] + reg
y_pred = step['y_pred']
acc = self.get_loss(y_pred, yi, training=False)
self.history.record_batch('max_param', max(p.abs().max().item() for p in self.module_.parameters()))
self.history.record_batch('acc', acc.item())
self.history.record_batch('reg', reg.item())
return {
'loss': loss,
'y_pred': y_pred,
}
def get_loss(self, y_pred, y_true, X=None, training=False):
batch_size, n_steps, _ = y_pred.shape
if training:
#return F.cross_entropy(y_pred, y_true)
y_true = y_true[:, None].expand(batch_size, n_steps)
return F.cross_entropy(
y_pred.reshape(batch_size * n_steps, -1),
y_true.reshape(-1)
)
else:
y_prob, y_cls = y_pred.max(-1)
_, y_prob_maxind = y_prob.max(-1)
y_cls_final = y_cls.gather(1, y_prob_maxind[:, None])[:, 0]
return (y_cls_final == y_true).sum()
class Dump(skorch.callbacks.Callback):
def initialize(self):
self.epoch = 0
self.batch = 0
self.correct = 0
self.total = 0
self.best_acc = 0
self.nviz = 0
return self
def on_epoch_begin(self, net, **kwargs):
self.epoch += 1
self.batch = 0
self.correct = 0
self.total = 0
self.nviz = 0
def on_batch_end(self, net, **kwargs):
self.batch += 1
if kwargs['training']:
#print('#', self.epoch, self.batch, kwargs['loss'], kwargs['valid_loss'])
pass
else:
self.correct += kwargs['loss'].item()
self.total += kwargs['X'].shape[0]
if self.nviz < 10:
n_nodes = len(net.module_.G.nodes)
fig, ax = init_canvas(n_nodes)
#a = T.stack([net.module_.G.get_repr(v)['a'] for v in net.module_.G.nodes], 1)
#a = F.softmax(a, 1).detach().cpu().numpy()
y = T.stack([net.module_.G.get_repr(v)['y'] for v in net.module_.G.nodes], 1)
y_val, y = y.max(-1)
for i, n in enumerate(net.module_.G.nodes):
repr_ = net.module_.G.get_repr(n)
g = repr_['g']
if g is None:
continue
b, _ = net.module_.update_module.glimpse.rescale(repr_['b'], False)
display_image(
fig,
ax,
i,
g[0],
np.array_str(
b[0].detach().cpu().numpy(),
precision=2, suppress_small=True) +
#'a=%.2f' % a[0, i, 0]
'y=%d (%.2f)' % (y[0, i], y_val[0, i])
)
wm.display_mpl_figure(fig, win='viz{}'.format(self.nviz))
self.nviz += 1
def on_epoch_end(self, net, **kwargs):
print('@', self.epoch, self.correct, '/', self.total)
acc = self.correct / self.total
if self.best_acc < acc:
self.best_acc = acc
net.history.record('acc_best', acc)
else:
net.history.record('acc_best', None)
def data_generator(dataset, batch_size, shuffle):
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True, num_workers=0)
for _x, _y, _B in dataloader:
x = _x[:, None].expand(_x.shape[0], 3, _x.shape[1], _x.shape[2]).float() / 255.
y = _y.squeeze(1)
yield cuda(x), cuda(y)
if __name__ == "__main__":
from datasets import MNISTMulti
from torch.utils.data import DataLoader
from sklearn.model_selection import GridSearchCV
mnist_train = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=True)
mnist_valid = MNISTMulti('.', n_digits=1, backrand=0, image_rows=200, image_cols=200, download=False, mode='valid')
for reg_coef in [0, 100, 1e-2, 0.1, 1, 1e-3]:
print('Trying reg coef', reg_coef)
net = Net(
module=DFSGlimpseSingleObjectClassifier,
criterion=None,
max_epochs=50,
reg_coef=reg_coef,
optimizer=T.optim.RMSprop,
#optimizer__weight_decay=1e-4,
lr=1e-5,
batch_size=batch_size,
device='cuda' if USE_CUDA else 'cpu',
callbacks=[
Dump(),
skorch.callbacks.Checkpoint(monitor='acc_best'),
skorch.callbacks.ProgressBar(postfix_keys=['train_loss', 'valid_loss', 'acc', 'reg']),
skorch.callbacks.GradientNormClipping(0.01),
#skorch.callbacks.LRScheduler('ReduceLROnPlateau'),
],
iterator_train=data_generator,
iterator_train__shuffle=True,
iterator_valid=data_generator,
iterator_valid__shuffle=False,
)
#net.fit((mnist_train, mnist_valid), pretrain=True, epochs=50)
net.partial_fit((mnist_train, mnist_valid), pretrain=False, epochs=500)
...@@ -80,7 +80,7 @@ class TreeLSTM(nn.Module): ...@@ -80,7 +80,7 @@ class TreeLSTM(nn.Module):
iterator.append((src, trg)) iterator.append((src, trg))
frontier = src frontier = src
g.recvfrom(leaves) g.recv(leaves)
g.propagate(reversed(iterator)) g.propagate(reversed(iterator))
return self.readout_func(g, train) return self.readout_func(g, train)
......
import torch as th
import torch.nn as nn
import torch.nn.functional as F
'''
Defult modules: this is Pytorch specific
- MessageModule: copy
- UpdateModule: vanilla RNN
- ReadoutModule: bag of words
- ReductionModule: bag of words
'''
class DefaultMessageModule(nn.Module):
"""
Default message module:
- copy
"""
def __init__(self, *args, **kwargs):
super(DefaultMessageModule, self).__init__(*args, **kwargs)
def forward(self, x):
return x
class DefaultUpdateModule(nn.Module):
"""
Default update module:
- a vanilla GRU with ReLU, or GRU
"""
def __init__(self, *args, **kwargs):
super(DefaultUpdateModule, self).__init__()
h_dims = self.h_dims = kwargs.get('h_dims', 128)
net_type = self.net_type = kwargs.get('net_type', 'fwd')
n_func = self.n_func = kwargs.get('n_func', 1)
self.f_idx = 0
self.reduce_func = DefaultReductionModule()
if net_type == 'gru':
self.net = [nn.GRUCell(h_dims, h_dims) for i in range(n_func)]
else:
self.net = [nn.Linear(2 * h_dims, h_dims) for i in range(n_func)]
def forward(self, x, msgs):
if not th.is_tensor(x):
x = th.zeros_like(msgs[0])
m = self.reduce_func(msgs)
assert(self.f_idx < self.n_func)
if self.net_type == 'gru':
out = self.net[self.f_idx](m, x)
else:
_in = th.cat((m, x), 1)
out = F.relu(self.net[self.f_idx](_in))
self.f_idx += 1
return out
def reset_f_idx(self):
self.f_idx = 0
class DefaultReductionModule(nn.Module):
"""
Default readout:
- bag of words
"""
def __init__(self, *args, **kwargs):
super(DefaultReductionModule, self).__init__(*args, **kwargs)
def forward(self, x_s):
out = th.stack(x_s)
out = th.sum(out, dim=0)
return out
class DefaultReadoutModule(nn.Module):
"""
Default readout:
- bag of words
"""
def __init__(self, *args, **kwargs):
super(DefaultReadoutModule, self).__init__(*args, **kwargs)
self.reduce_func = DefaultReductionModule()
def forward(self, x_s):
return self.reduce_func(x_s)
...@@ -307,21 +307,18 @@ class DGLGraph(DiGraph): ...@@ -307,21 +307,18 @@ class DGLGraph(DiGraph):
""" """
self._internal_trigger_edges(u, v, __EFUNC__) self._internal_trigger_edges(u, v, __EFUNC__)
def recvfrom(self, u, preds=None): def recv(self, u):
"""Trigger the update function on node u. """Receive in-coming messages and update representation on node u.
It computes the new node state using the messages and edge It computes the new node state using the messages sent from the predecessors
states from preds->u. If `u` is one node, `preds` is a list of node u. If no message is found from the predecessors, reduce function
of predecessors. If `u` is a container or tensor of nodes, will be skipped and a None type will be provided as the reduced messages for
then `preds[i]` should be the predecessors of `u[i]`. the update function.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
The node to be updated. The node to be updated.
preds : container
Nodes with pre-computed messages to u. Default is all
the predecessors.
""" """
u_is_container = isinstance(u, list) u_is_container = isinstance(u, list)
u_is_tensor = isinstance(u, Tensor) u_is_tensor = isinstance(u, Tensor)
...@@ -329,19 +326,19 @@ class DGLGraph(DiGraph): ...@@ -329,19 +326,19 @@ class DGLGraph(DiGraph):
ufunc = self._glb_func.get(__UFUNC__) ufunc = self._glb_func.get(__UFUNC__)
# TODO(minjie): tensorize the loop. # TODO(minjie): tensorize the loop.
for i, uu in enumerate(utils.node_iter(u)): for i, uu in enumerate(utils.node_iter(u)):
if preds is None:
v = list(self.pred[uu])
elif u_is_container or u_is_tensor:
v = preds[i]
else:
v = preds
# TODO(minjie): tensorize the message batching # TODO(minjie): tensorize the message batching
# reduce phase # reduce phase
f_reduce = self.nodes[uu].get(__RFUNC__, rfunc) f_reduce = self.nodes[uu].get(__RFUNC__, rfunc)
assert f_reduce is not None, \ assert f_reduce is not None, \
"Reduce function not registered for node %s" % uu "Reduce function not registered for node %s" % uu
msgs_batch = [self.edges[vv, uu][__MSG__] for vv in v] msgs_batch = [self.edges[vv, uu].pop(__MSG__)
msgs_reduced = f_reduce(msgs_batch) for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) == 0:
msgs_reduced = None
elif len(msgs_batch) == 1:
msgs_reduced = msgs_batch[0]
else:
msgs_reduced = f_reduce(msgs_batch)
# update phase # update phase
f_update = self.nodes[uu].get(__UFUNC__, ufunc) f_update = self.nodes[uu].get(__UFUNC__, ufunc)
assert f_update is not None, \ assert f_update is not None, \
...@@ -361,17 +358,10 @@ class DGLGraph(DiGraph): ...@@ -361,17 +358,10 @@ class DGLGraph(DiGraph):
""" """
self.sendto(u, v) self.sendto(u, v)
# TODO(minjie): tensorize the following loops. # TODO(minjie): tensorize the following loops.
preds = defaultdict(list) dst = set()
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
preds[vv].append(uu) dst.add(vv)
if len(preds) == 1: self.recv(list(dst))
dst = list(preds.keys())[0]
src = preds[dst]
self.recvfrom(dst, src)
elif len(preds) > 1:
dst = list(preds.keys())
src = [preds[d] for d in dst]
self.recvfrom(dst, src)
def update_to(self, u): def update_to(self, u):
"""Pull messages from the node's predecessors and then update it. """Pull messages from the node's predecessors and then update it.
...@@ -386,7 +376,7 @@ class DGLGraph(DiGraph): ...@@ -386,7 +376,7 @@ class DGLGraph(DiGraph):
assert uu in self.nodes assert uu in self.nodes
preds = list(self.pred[uu]) preds = list(self.pred[uu])
self.sendto(preds, uu) self.sendto(preds, uu)
self.recvfrom(uu, preds) self.recv(uu)
def update_from(self, u): def update_from(self, u):
"""Send message from the node to its successors and update them. """Send message from the node to its successors and update them.
...@@ -409,7 +399,7 @@ class DGLGraph(DiGraph): ...@@ -409,7 +399,7 @@ class DGLGraph(DiGraph):
u = [uu for uu, _ in self.edges] u = [uu for uu, _ in self.edges]
v = [vv for _, vv in self.edges] v = [vv for _, vv in self.edges]
self.sendto(u, v) self.sendto(u, v)
self.recvfrom(list(self.nodes())) self.recv(list(self.nodes()))
def propagate(self, iterator='bfs', **kwargs): def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
......
...@@ -34,11 +34,11 @@ def test_sendrecv(): ...@@ -34,11 +34,11 @@ def test_sendrecv():
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum') g.register_reduce_func('sum')
g.sendto(0, 1) g.sendto(0, 1)
g.recvfrom(1, [0]) g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9) g.sendto(5, 9)
g.sendto(6, 9) g.sendto(6, 9)
g.recvfrom(9, [5, 6]) g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25]) check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
def message_func_hybrid(src, dst, edge): def message_func_hybrid(src, dst, edge):
...@@ -55,11 +55,11 @@ def test_hybridrepr(): ...@@ -55,11 +55,11 @@ def test_hybridrepr():
g.register_update_func(update_func_hybrid) g.register_update_func(update_func_hybrid)
g.register_reduce_func('sum') g.register_reduce_func('sum')
g.sendto(0, 1) g.sendto(0, 1)
g.recvfrom(1, [0]) g.recv(1)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9) g.sendto(5, 9)
g.sendto(6, 9) g.sendto(6, 9)
g.recvfrom(9, [5, 6]) g.recv(9)
check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25]) check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -30,11 +30,11 @@ def test_sendrecv(): ...@@ -30,11 +30,11 @@ def test_sendrecv():
g.register_update_func(update_func) g.register_update_func(update_func)
g.register_reduce_func('sum') g.register_reduce_func('sum')
g.sendto(0, 1) g.sendto(0, 1)
g.recvfrom(1, [0]) g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
g.sendto(5, 9) g.sendto(5, 9)
g.sendto(6, 9) g.sendto(6, 9)
g.recvfrom(9, [5, 6]) g.recv(9)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23]) check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])
def test_multi_sendrecv(): def test_multi_sendrecv():
...@@ -45,15 +45,15 @@ def test_multi_sendrecv(): ...@@ -45,15 +45,15 @@ def test_multi_sendrecv():
g.register_reduce_func('sum') g.register_reduce_func('sum')
# one-many # one-many
g.sendto(0, [1, 2, 3]) g.sendto(0, [1, 2, 3])
g.recvfrom([1, 2, 3], [[0], [0], [0]]) g.recv([1, 2, 3])
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10]) check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
# many-one # many-one
g.sendto([6, 7, 8], 9) g.sendto([6, 7, 8], 9)
g.recvfrom(9, [6, 7, 8]) g.recv(9)
check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34]) check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
# many-many # many-many
g.sendto([0, 0, 4, 5], [4, 5, 9, 9]) g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
g.recvfrom([4, 5, 9], [[0], [0], [4, 5]]) g.recv([4, 5, 9])
check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45]) check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])
def test_update_routines(): def test_update_routines():
......
from dgl import DGLGraph
from dgl.graph import __REPR__
def message_func(hu, hv, e_uv):
return hu
def message_not_called(hu, hv, e_uv):
assert False
return hu
def reduce_not_called(msgs):
assert False
return 0
def update_no_msg(h, accum):
assert accum is None
return h + 1
def update_func(h, accum):
assert accum is not None
return h + accum
def check(g, h):
nh = [str(g.get_n_repr(i)) for i in range(10)]
h = [str(x) for x in h]
assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
def generate_graph():
g = DGLGraph()
for i in range(10):
g.add_node(i) # 10 nodes.
g.set_n_repr(i, i+1)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.set_e_repr(0, i, 1)
g.add_edge(i, 9)
g.set_e_repr(i, 9, 1)
return g
def test_no_msg_update():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_not_called)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_no_msg)
for i in range(10):
g.recv(i)
check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
def test_double_recv():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func('sum')
g.register_update_func(update_func)
g.sendto(1, 9)
g.sendto(2, 9)
g.recv(9)
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
try:
# The second recv should have a None message
g.recv(9)
except:
return
assert False
def test_recv_no_pred():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_not_called)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_no_msg)
g.recv(0)
def test_skipped_reduce():
g = generate_graph()
check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
g.register_message_func(message_func)
g.register_reduce_func(reduce_not_called)
g.register_update_func(update_func)
g.sendto(0, 1)
g.recv(1)
check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
if __name__ == '__main__':
test_no_msg_update()
test_double_recv()
test_recv_no_pred()
test_skipped_reduce()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment