Commit 5567f4a4 authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang

[BACKEND] Add MXNet backend. (#77)

* support mxnet.

* add mxnet version of GCN.

* rename mxnet.nd as F.

* add mxnet GAT.

* enable GPU for GCN.

* fix MXNet GCN train.

* Use adam to optimize GAT

* support more operators.

* support sparse arrays.

* update mxnet backend.

* support index_copy.

* remove NN.

* update mxnet backend.

* temp check in.

* fix data conversion.

* add test.

* clean up mxnet backend.

* update mxnet examples.

* Revert "remove NN."

This reverts commit d815d9a0ec619f9ce9099c48cd35db9d8e947483.

* temp disable MXNet version of NN.
parent f31b6fd2
"""
Graph Attention Networks
Paper: https://arxiv.org/abs/1710.10903
Code: https://github.com/PetarV-/GAT
GAT with batch processing
"""
import argparse
import numpy as np
import time
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def elu(data):
return mx.nd.LeakyReLU(data, act_type='elu')
def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
class GATReduce(gluon.Block):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = mx.nd.expand_dims(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
        # normalize over the neighbors (axis 1); slope=0.2 follows the GAT paper
        e = mx.nd.softmax(mx.nd.LeakyReLU(a, slope=0.2), axis=1)
if self.attn_drop != 0.0:
e = mx.nd.Dropout(e, self.attn_drop)
return {'accum' : mx.nd.sum(e * ft, axis=1)} # shape (B, D)
class GATFinalize(gluon.Block):
def __init__(self, headid, indim, hiddendim, activation, residual):
super(GATFinalize, self).__init__()
self.headid = headid
self.activation = activation
self.residual = residual
self.residual_fc = None
if residual:
if indim != hiddendim:
self.residual_fc = gluon.nn.Dense(hiddendim)
def forward(self, node):
ret = node['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
else:
ret = node['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(gluon.Block):
def __init__(self, indim, hiddendim, drop):
super(GATPrepare, self).__init__()
self.fc = gluon.nn.Dense(hiddendim)
self.drop = drop
self.attn_l = gluon.nn.Dense(1)
self.attn_r = gluon.nn.Dense(1)
def forward(self, feats):
h = feats
if self.drop != 0.0:
h = mx.nd.Dropout(h, self.drop)
ft = self.fc(h)
a1 = self.attn_l(ft)
a2 = self.attn_r(ft)
return {'h' : h, 'ft' : ft, 'a1' : a1, 'a2' : a2}
class GAT(gluon.Block):
def __init__(self,
g,
num_layers,
in_dim,
num_hidden,
num_classes,
num_heads,
activation,
in_drop,
attn_drop,
residual):
super(GAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.num_heads = num_heads
self.prp = gluon.nn.Sequential()
self.red = gluon.nn.Sequential()
self.fnl = gluon.nn.Sequential()
# input projection (no residual)
for hid in range(num_heads):
self.prp.add(GATPrepare(in_dim, num_hidden, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(hid, in_dim, num_hidden, activation, False))
# hidden layers
for l in range(num_layers - 1):
for hid in range(num_heads):
# due to multi-head, the in_dim = num_hidden * num_heads
self.prp.add(GATPrepare(num_hidden * num_heads, num_hidden, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(hid, num_hidden * num_heads,
num_hidden, activation, residual))
# output projection
self.prp.add(GATPrepare(num_hidden * num_heads, num_classes, in_drop))
self.red.add(GATReduce(attn_drop))
self.fnl.add(GATFinalize(0, num_hidden * num_heads,
num_classes, activation, residual))
# sanity check
assert len(self.prp) == self.num_layers * self.num_heads + 1
assert len(self.red) == self.num_layers * self.num_heads + 1
assert len(self.fnl) == self.num_layers * self.num_heads + 1
def forward(self, features):
last = features
for l in range(self.num_layers):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
self.g.set_n_repr(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
last = mx.nd.concat(
*[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection
self.g.set_n_repr(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
def main(args):
# load and preprocess dataset
data = load_data(args)
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
mask = mx.nd.array(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
    if args.gpu < 0:
        cuda = False
        ctx = mx.cpu(0)
    else:
        cuda = True
        ctx = mx.gpu(args.gpu)
        features = features.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        mask = mask.as_in_context(ctx)
# create GCN model
g = DGLGraph(data.graph)
# create model
model = GAT(g,
args.num_layers,
in_feats,
args.num_hidden,
n_classes,
args.num_heads,
elu,
args.in_drop,
args.attn_drop,
args.residual)
    model.initialize(ctx=ctx)
# use optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})
# initialize graph
dur = []
for epoch in range(args.epochs):
if epoch >= 3:
t0 = time.time()
# forward
with mx.autograd.record():
logits = model(features)
loss = mx.nd.softmax_cross_entropy(logits, labels)
loss.backward()
trainer.step(features.shape[0])
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GAT')
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="Which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=3,
help="number of attentional heads to use")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="size of hidden units")
parser.add_argument("--residual", action="store_false",
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.005,
help="learning rate")
args = parser.parse_args()
print(args)
main(args)
Graph Convolutional Networks (GCN)
============
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)
The folder contains three different implementations using DGL.
Naive GCN (gcn.py)
-------
The model is defined at the finest granularity (i.e., on *one* edge and *one* node).
* The message function `gcn_msg` computes the message for one edge. It simply returns the `h` representation of the source node.
```python
def gcn_msg(src, edge):
# src['h'] is a tensor of shape (D,). D is the feature length.
return src['h']
```
* The reduce function `gcn_reduce` accumulates the incoming messages for one node. The `msgs` argument is a list of all the messages. In GCN, the incoming messages are summed up.
```python
def gcn_reduce(node, msgs):
    # msgs is a list of incoming messages.
return sum(msgs)
```
* The update function `NodeUpdateModule` computes the new node representation `h` using a non-linear transformation on the reduced messages.
```python
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
# accum is a tensor of shape (D,).
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return {'h' : h}
```
After defining the functions on each node/edge, message passing is triggered by calling `update_all` on the DGLGraph object (in the GCN module).
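For concreteness, here is a minimal sketch of wiring these three pieces together on a toy graph. This is a hedged example: it reuses `gcn_msg`, `gcn_reduce`, and `NodeUpdateModule` from the snippets above, assumes the dict-style `set_n_repr`/`pop_n_repr` convention used elsewhere in this commit, and the graph and feature sizes are made up.

```python
import torch
import torch.nn.functional as F
from dgl import DGLGraph

# toy 4-node cycle graph (illustrative only)
g = DGLGraph()
g.add_nodes(4)
for u, v in [(0, 1), (1, 2), (2, 3), (3, 0)]:
    g.add_edge(u, v)

features = torch.randn(4, 16)                        # one 16-d feature per node
layer = NodeUpdateModule(16, 16, activation=F.relu)  # module defined above

g.set_n_repr({'h': features})             # write node features
g.update_all(gcn_msg, gcn_reduce, layer)  # message -> reduce -> update
h = g.pop_n_repr('h')                     # read the updated representations
```

Stacking several such `update_all` calls, one per layer, yields the full model in gcn.py.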
Batched GCN (gcn_batch.py)
-----------
Defining the model on only one node and edge makes it hard to fully utilize GPUs. As a result, we allow users to define the model on a *batch of* nodes and edges.
* The message function `gcn_msg` computes the message for a batch of edges. Here, the `src` argument is the batched representation of the source endpoints of the edges. The function simply returns the source node representations.
```python
def gcn_msg(src, edge):
# src is a tensor of shape (B, D). B is the number of edges being batched.
return src
```
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension of the `msgs` argument:
```python
def gcn_reduce(node, msgs):
    # msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch;
    # deg is the number of messages; D is the message tensor dimension. DGL guarantees
    # that all the nodes in a batch have the same in-degree (through "degree-bucketing").
    # Reducing along the second dimension is equivalent to summing up all the incoming messages.
return torch.sum(msgs, 1)
```
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operations are usually aware of the batch dimension, the code is the same as in the naive GCN.
Triggering message passing is also similar. Users need to set `batchable=True` to indicate that the functions all support batching.
```python
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True)
```
Batched GCN with spMV optimization (gcn_spmv.py)
-----------
Batched computation is much more efficient than the naive vertex-centric approach, but it is still not ideal. For example, the batched message function needs to look up the source node data and save it on the edges. Such lookups are very common and incur extra memory copies. In fact, the message and reduce phases of the GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions, which let it recognize such optimization opportunities. In gcn_spmv.py, the user only needs to write the update module and trigger the message passing as follows:
```python
self.g.update_all('from_src', 'sum', layer, batchable=True)
```
Here, `'from_src'` and `'sum'` are the built-in message and reduce functions.
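Putting the pieces together, a layer stack in the style of gcn_spmv.py might look like the following sketch. The `GCNSpMV` module is hypothetical; only the `'from_src'`/`'sum'` builtins and the `update_all` call above are taken from this README, and `NodeUpdateModule` is reused from the earlier snippet.

```python
import torch.nn as nn
import torch.nn.functional as F

class GCNSpMV(nn.Module):
    """Two-layer GCN whose message/reduce phases are fused into spMV."""
    def __init__(self, g, in_feats, n_hidden, n_classes):
        super(GCNSpMV, self).__init__()
        self.g = g
        self.layer1 = NodeUpdateModule(in_feats, n_hidden, activation=F.relu)
        self.layer2 = NodeUpdateModule(n_hidden, n_classes)

    def forward(self, features):
        self.g.set_n_repr(features)
        # builtin message/reduce let DGL execute each layer as one spMV
        self.g.update_all('from_src', 'sum', self.layer1, batchable=True)
        self.g.update_all('from_src', 'sum', self.layer2, batchable=True)
        return self.g.pop_n_repr()
```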
"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with batch processing
"""
import argparse
import numpy as np
import time
import mxnet as mx
from mxnet import gluon
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
return src
def gcn_reduce(node, msgs):
return mx.nd.sum(msgs, 1)
class NodeUpdateModule(gluon.Block):
def __init__(self, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = gluon.nn.Dense(out_feats, activation=activation)
def forward(self, node):
return self.linear(node)
class GCN(gluon.Block):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super(GCN, self).__init__()
self.g = g
self.dropout = dropout
# input layer
self.layers = gluon.nn.Sequential()
self.layers.add(NodeUpdateModule(n_hidden, activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.add(NodeUpdateModule(n_hidden, activation))
# output layer
self.layers.add(NodeUpdateModule(n_classes))
def forward(self, features):
self.g.set_n_repr(features)
for layer in self.layers:
# apply dropout
if self.dropout:
                val = mx.nd.Dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr()
def main(args):
# load and preprocess dataset
data = load_data(args)
features = mx.nd.array(data.features)
labels = mx.nd.array(data.labels)
mask = mx.nd.array(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
    if args.gpu < 0:
        cuda = False
        ctx = mx.cpu(0)
    else:
        cuda = True
        ctx = mx.gpu(args.gpu)
        features = features.as_in_context(ctx)
        labels = labels.as_in_context(ctx)
        mask = mask.as_in_context(ctx)
# create GCN model
g = DGLGraph(data.graph)
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
'relu',
args.dropout)
model.initialize(ctx=ctx)
# use optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': args.lr})
# initialize graph
dur = []
for epoch in range(args.n_epochs):
if epoch >= 3:
t0 = time.time()
# forward
with mx.autograd.record():
logits = model(features)
loss = mx.nd.softmax_cross_entropy(logits, labels)
loss.backward()
trainer.step(features.shape[0])
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN')
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden gcn layers")
args = parser.parse_args()
main(args)
@@ -7,5 +7,7 @@ if __backend__ == 'numpy':
     from .numpy import *
 elif __backend__ == 'pytorch':
     from .pytorch import *
+elif __backend__ == 'mxnet':
+    from .mxnet import *
 else:
     raise Exception("Unsupported backend %s" % __backend__)
from __future__ import absolute_import
import numpy as np
import mxnet as mx
import mxnet.ndarray as F
import scipy.sparse
import ctypes
from .._ffi.base import _LIB, check_call, c_array
from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
# Tensor types
Tensor = mx.nd.NDArray
SparseTensor = mx.nd.sparse.CSRNDArray
# Data types
float16 = np.float16
float32 = np.float32
float64 = np.float64
uint8 = np.uint8
int8 = np.int8
int16 = np.int16
int32 = np.int32
int64 = np.int64
# Operators
tensor = mx.nd.array
sum = F.sum
def max(x):
return F.max(x).asnumpy()[0]
def sparse_tensor(idx, data, shape):
return mx.nd.sparse.csr_matrix((data, (idx[0], idx[1])), tuple(shape))
def astype(a, ty):
return F.cast(a, ty)
def asnumpy(a):
return a.asnumpy()
def from_numpy(np_data):
return mx.nd.array(np_data, dtype=np_data.dtype)
def pack(tensors):
return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1):
    # MXNet counterpart of torch.split; num_outputs is the number of pieces.
    return F.split(x, num_outputs=indices_or_sections)
# TODO this doesn't exist for symbol.
def shape(x):
return x.shape
def dtype(x):
return x.dtype
def isinteger(x):
    return x.dtype in [np.int8, np.int16, np.int32, np.int64]
def unique(x):
# TODO this isn't the best way of running unique.
tmp = x.asnumpy()
tmp = np.unique(tmp)
return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)
def gather_row(data, row_index):
    return data[row_index]
scatter_row = mx.nd.contrib.index_copy
def broadcast_to(x, to_array):
return x + F.zeros_like(to_array)
squeeze = F.squeeze
unsqueeze = F.expand_dims
# TODO this doesn't exist for symbol.
reshape = F.reshape
ones = F.ones
zeros = F.zeros
arange = F.arange
def spmm(spm, mat):
return mx.nd.dot(spm, mat)
def sort(x, dim=None, descending=False):
if dim is None:
dim = -1
ascend = not descending
# TODO this isn't an ideal implementation.
val = F.sort(x, axis=dim, is_ascend=ascend)
idx = F.argsort(x, axis=dim, is_ascend=ascend)
idx = F.cast(idx, dtype='int64')
return val, idx
def to_context(x, ctx):
if ctx is None:
return x
elif ctx.device_type == TVMContext.STR2MASK['cuda']:
return x.as_in_context(mx.gpu(ctx.device_id))
elif ctx.device_type == TVMContext.STR2MASK['cpu']:
return x.as_in_context(mx.cpu())
else:
raise RuntimeError('Invalid context', ctx)
def get_context(x):
if x.context.device_type == 'cpu':
return TVMContext(TVMContext.STR2MASK['cpu'], 0)
else:
return TVMContext(
TVMContext.STR2MASK[x.context.device_type], x.context.device_id)
def _typestr(arr_dtype):
return arr_dtype
def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy."""
return arr.to_dlpack_for_read()
def zerocopy_from_dlpack(dlpack_arr):
"""Return a tensor using zero copy."""
return mx.nd.from_dlpack(dlpack_arr)
def zerocopy_to_numpy(arr):
"""Return a numpy array that shares the data."""
return arr.asnumpy()
def zerocopy_from_numpy(np_data):
"""Return a tensor that shares the numpy data."""
return mx.nd.array(np_data, dtype=np_data.dtype)
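As a quick, hedged smoke test of the backend above (it assumes the module is importable as `dgl.backend.mxnet`, matching the `__init__.py` change below; all names used are defined in this file):

```python
import numpy as np
import mxnet as mx
from dgl.backend.mxnet import tensor, sparse_tensor, spmm, sort, asnumpy

# 3x3 sparse matrix with two nonzeros: (0, 1) and (2, 0)
idx = np.array([[0, 2],
                [1, 0]])          # (row indices, col indices)
data = mx.nd.array([1.0, 1.0])
adj = sparse_tensor(idx, data, (3, 3))   # CSRNDArray

feat = tensor(np.arange(9, dtype=np.float32).reshape(3, 3))
out = spmm(adj, feat)                    # sparse-dense matmul
print(asnumpy(out))                      # row 0 <- feat row 1, row 2 <- feat row 0

val, order = sort(tensor([3.0, 1.0, 2.0]))
print(asnumpy(val), asnumpy(order))      # ascending values and int64 indices
```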
@@ -9,5 +9,5 @@ if __backend__ == 'numpy':
     pass
 elif __backend__ == 'pytorch':
     from .pytorch import *
-else:
+elif __backend__ != 'mxnet':
     raise Exception("Unsupported backend %s" % __backend__)
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
from dgl.graph import DGLGraph
D = 5
reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
    assert mx.nd.sum(a == b).asscalar() == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D
return {'m' : src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'m' : mx.nd.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['m']}
def generate_graph(grad=False):
g = DGLGraph()
g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.set_n_repr({'h' : ncol})
return g
def test_batch_setter_getter():
def _pfc(x):
return list(x.asnumpy()[:,0])
g = generate_graph()
# set all nodes
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes
# TODO we need to enable the test later.
'''
u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64')
print(g.get_n_repr(u)['h'])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
'''
'''
s, d, eid
0, 1, 0
1, 9, 1
0, 2, 2
2, 9, 3
0, 3, 4
3, 9, 5
0, 4, 6
4, 9, 7
0, 5, 8
5, 9, 9
0, 6, 10
6, 9, 11
0, 7, 12
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
# set all edges
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# pop edges
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
# set partial edges (many-many)
u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((5, D))}, u, v)
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# set partial edges (many-one)
u = mx.nd.array([3, 4, 6], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# set partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([4, 5, 6], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
# get partial edges (many-many)
u = mx.nd.array([0, 6, 0], dtype='int64')
v = mx.nd.array([6, 9, 7], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = mx.nd.array([5, 6, 7], dtype='int64')
v = mx.nd.array([9], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([3, 4, 5], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
with mx.autograd.record():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
# partial set
v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D))
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
h2.backward(mx.nd.ones((10, D)) * 2)
check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
g.register_message_func(_fmsg)
# many-many send
u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
# one-many send
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
# many-one send
u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.send(u, v)
def test_batch_recv():
# basic recv test
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
reduce_msg_shapes.clear()
g.send(u, v)
#g.recv(th.unique(v))
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear()
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
g.register_apply_node_func(apply_node_func)
# send_and_recv
reduce_msg_shapes.clear()
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
g.send_and_recv(u, v)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
# pull
v = mx.nd.array([1, 2, 3, 9], dtype='int64')
reduce_msg_shapes.clear()
g.pull(v)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# push
v = mx.nd.array([0, 1, 2, 3], dtype='int64')
reduce_msg_shapes.clear()
g.push(v)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
reduce_msg_shapes.clear()
# update_all
reduce_msg_shapes.clear()
g.update_all()
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
def test_reduce_0deg():
g = DGLGraph()
g.add_nodes(5)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return src
def _reduce(node, msgs):
assert msgs is not None
return node + msgs.sum(1)
old_repr = mx.nd.random.normal(shape=(5, 5))
g.set_n_repr(old_repr)
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return src
def _reduce(node, msgs):
assert msgs is not None
return msgs.sum(1)
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr(old_repr)
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr(old_repr)
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
if __name__ == '__main__':
test_batch_setter_getter()
# TODO we need to enable it after index_copy is implemented.
#test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_routines()
test_reduce_0deg()
test_pull_0deg()