Unverified Commit 9c135fd5 authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #4 from jermainewang/master

Sync with latest commit
parents 9d3f299d 00add9f2
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file workspace_pool.h * \file workspace_pool.h
* \brief Workspace pool utility. * \brief Workspace pool utility.
*/ */
#ifndef TVM_RUNTIME_WORKSPACE_POOL_H_ #ifndef DGL_RUNTIME_WORKSPACE_POOL_H_
#define TVM_RUNTIME_WORKSPACE_POOL_H_ #define DGL_RUNTIME_WORKSPACE_POOL_H_
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <vector> #include <vector>
...@@ -58,4 +58,4 @@ class WorkspacePool { ...@@ -58,4 +58,4 @@ class WorkspacePool {
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_WORKSPACE_POOL_H_ #endif // DGL_RUNTIME_WORKSPACE_POOL_H_
// DGL Scheduler implementation /*!
* Copyright (c) 2018 by Contributors
* \file scheduler/scheduler.cc
* \brief DGL Scheduler implementation
*/
#include <dgl/scheduler.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <dgl/scheduler.h>
namespace dgl { namespace dgl {
namespace sched { namespace sched {
...@@ -19,7 +22,7 @@ std::vector<IdArray> DegreeBucketing(const IdArray& vids) { ...@@ -19,7 +22,7 @@ std::vector<IdArray> DegreeBucketing(const IdArray& vids) {
// bkt: deg->dsts // bkt: deg->dsts
std::unordered_map<int64_t, std::vector<int64_t>> bkt; std::unordered_map<int64_t, std::vector<int64_t>> bkt;
for (auto& it: in_edges) { for (const auto& it : in_edges) {
bkt[it.second.size()].push_back(it.first); bkt[it.second.size()].push_back(it.first);
} }
...@@ -38,15 +41,15 @@ std::vector<IdArray> DegreeBucketing(const IdArray& vids) { ...@@ -38,15 +41,15 @@ std::vector<IdArray> DegreeBucketing(const IdArray& vids) {
int64_t* msec_ptr = static_cast<int64_t*>(mid_section->data); int64_t* msec_ptr = static_cast<int64_t*>(mid_section->data);
// fill in bucketing ordering // fill in bucketing ordering
for (auto& it: bkt) { // for each bucket for (const auto& it : bkt) { // for each bucket
int64_t deg = it.first; const int64_t deg = it.first;
int64_t n_dst = it.second.size(); const int64_t n_dst = it.second.size();
*deg_ptr++ = deg; *deg_ptr++ = deg;
*nsec_ptr++ = n_dst; *nsec_ptr++ = n_dst;
*msec_ptr++ = deg * n_dst; *msec_ptr++ = deg * n_dst;
for (auto dst: it.second) { // for each dst in this bucket for (const auto dst : it.second) { // for each dst in this bucket
*nid_ptr++ = dst; *nid_ptr++ = dst;
for (auto mid: in_edges[dst]) { // for each in edge of dst for (const auto mid : in_edges[dst]) { // for each in edge of dst
*mid_ptr++ = mid; *mid_ptr++ = mid;
} }
} }
......
#include "../c_api_common.h" /*!
* Copyright (c) 2018 by Contributors
* \file scheduler/scheduler_apis.cc
* \brief DGL scheduler APIs
*/
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/scheduler.h> #include <dgl/scheduler.h>
#include "../c_api_common.h"
using tvm::runtime::TVMArgs; using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue; using tvm::runtime::TVMRetValue;
...@@ -18,7 +23,7 @@ TVM_REGISTER_GLOBAL("scheduler._CAPI_DGLDegreeBucketingFromGraph") ...@@ -18,7 +23,7 @@ TVM_REGISTER_GLOBAL("scheduler._CAPI_DGLDegreeBucketingFromGraph")
.set_body([] (TVMArgs args, TVMRetValue* rv) { .set_body([] (TVMArgs args, TVMRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
const Graph* gptr = static_cast<Graph*>(ghandle); const Graph* gptr = static_cast<Graph*>(ghandle);
auto edges = gptr->Edges(false); const auto& edges = gptr->Edges(false);
*rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(edges.dst)); *rv = ConvertNDArrayVectorToPackedFunc(sched::DegreeBucketing(edges.dst));
}); });
......
...@@ -9,7 +9,7 @@ reduce_msg_shapes = set() ...@@ -9,7 +9,7 @@ reduce_msg_shapes = set()
def check_eq(a, b): def check_eq(a, b):
assert a.shape == b.shape assert a.shape == b.shape
assert mx.sum(a == b) == int(np.prod(list(a.shape))) assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge): def message_func(src, edge):
assert len(src['h'].shape) == 2 assert len(src['h'].shape) == 2
...@@ -53,16 +53,12 @@ def test_batch_setter_getter(): ...@@ -53,16 +53,12 @@ def test_batch_setter_getter():
assert len(g.get_n_repr()) == 0 assert len(g.get_n_repr()) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))}) g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes # set partial nodes
# TODO we need to enable the test later.
'''
u = mx.nd.array([1, 3, 5], dtype='int64') u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u) g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.] assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes # get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64') u = mx.nd.array([1, 2, 3], dtype='int64')
print(g.get_n_repr(u)['h'])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.] assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
'''
''' '''
s, d, eid s, d, eid
...@@ -127,9 +123,11 @@ def test_batch_setter_autograd(): ...@@ -127,9 +123,11 @@ def test_batch_setter_autograd():
with mx.autograd.record(): with mx.autograd.record():
g = generate_graph(grad=True) g = generate_graph(grad=True)
h1 = g.get_n_repr()['h'] h1 = g.get_n_repr()['h']
h1.attach_grad()
# partial set # partial set
v = mx.nd.array([1, 2, 8], dtype='int64') v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D)) hh = mx.nd.zeros((len(v), D))
hh.attach_grad()
g.set_n_repr({'h' : hh}, v) g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h'] h2 = g.get_n_repr()['h']
h2.backward(mx.nd.ones((10, D)) * 2) h2.backward(mx.nd.ones((10, D)) * 2)
...@@ -252,8 +250,7 @@ def test_pull_0deg(): ...@@ -252,8 +250,7 @@ def test_pull_0deg():
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
# TODO we need to enable it after index_copy is implemented. test_batch_setter_autograd()
#test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_routines() test_update_routines()
......
...@@ -20,22 +20,26 @@ def reduce_func(node, msgs): ...@@ -20,22 +20,26 @@ def reduce_func(node, msgs):
reduce_msg_shapes.add(tuple(msgs.shape)) reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3 assert len(msgs.shape) == 3
assert msgs.shape[2] == D assert msgs.shape[2] == D
return {'m' : th.sum(msgs, 1)} return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node): def apply_node_func(node):
return {'h' : node['h'] + node['m']} return {'h' : node['h'] + node['accum']}
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
g.add_nodes(10) # 10 nodes. g.add_nodes(10) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
# 17 edges
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
g.add_edge(i, 9) g.add_edge(i, 9)
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad) ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol}) g.set_n_repr({'h' : ncol})
g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
...@@ -46,8 +50,9 @@ def test_batch_setter_getter(): ...@@ -46,8 +50,9 @@ def test_batch_setter_getter():
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10 assert _pfc(g.get_n_repr()['h']) == [0.] * 10
# pop nodes # pop nodes
old_len = len(g.get_n_repr())
assert _pfc(g.pop_n_repr('h')) == [0.] * 10 assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0 assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))}) g.set_n_repr({'h' : th.zeros((10, D))})
# set partial nodes # set partial nodes
u = th.tensor([1, 3, 5]) u = th.tensor([1, 3, 5])
...@@ -81,8 +86,9 @@ def test_batch_setter_getter(): ...@@ -81,8 +86,9 @@ def test_batch_setter_getter():
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17 assert _pfc(g.get_e_repr()['l']) == [0.] * 17
# pop edges # pop edges
old_len = len(g.get_e_repr())
assert _pfc(g.pop_e_repr('l')) == [0.] * 17 assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0 assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))}) g.set_e_repr({'l' : th.zeros((17, D))})
# set partial edges (many-many) # set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9]) u = th.tensor([0, 0, 2, 5, 9])
...@@ -203,14 +209,13 @@ def test_reduce_0deg(): ...@@ -203,14 +209,13 @@ def test_reduce_0deg():
g.add_edge(3, 0) g.add_edge(3, 0)
g.add_edge(4, 0) g.add_edge(4, 0)
def _message(src, edge): def _message(src, edge):
return src return {'m' : src['h']}
def _reduce(node, msgs): def _reduce(node, msgs):
assert msgs is not None return {'h' : node['h'] + msgs['m'].sum(1)}
return node + msgs.sum(1)
old_repr = th.randn(5, 5) old_repr = th.randn(5, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.update_all(_message, _reduce) g.update_all(_message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[1:], old_repr[1:]) assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0)) assert th.allclose(new_repr[0], old_repr.sum(0))
...@@ -220,29 +225,30 @@ def test_pull_0deg(): ...@@ -220,29 +225,30 @@ def test_pull_0deg():
g.add_nodes(2) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
def _message(src, edge): def _message(src, edge):
return src return {'m' : src['h']}
def _reduce(node, msgs): def _reduce(node, msgs):
assert msgs is not None return {'h' : msgs['m'].sum(1)}
return msgs.sum(1)
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce) g.pull(0, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1]) assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce) g.pull(1, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce) g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
def test_send_twice(): def _disabled_test_send_twice():
# TODO(minjie): please re-enable this unittest after the send code problem is fixed.
g = DGLGraph() g = DGLGraph()
g.add_nodes(3) g.add_nodes(3)
g.add_edge(0, 1) g.add_edge(0, 1)
...@@ -348,5 +354,4 @@ if __name__ == '__main__': ...@@ -348,5 +354,4 @@ if __name__ == '__main__':
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
test_send_twice()
test_send_multigraph() test_send_multigraph()
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph, __REPR__
D = 32
reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(hu, e_uv):
assert len(hu.shape) == 2
assert hu.shape[1] == D
return hu
def reduce_func(hv, msgs):
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return hv + th.sum(msgs, 1)
def generate_graph(grad=False):
g = DGLGraph()
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
col = Variable(th.randn(10, D), requires_grad=grad)
g.set_n_repr(col)
return g
def test_batch_setter_getter():
def _pfc(x):
return list(x.numpy()[:,0])
g = generate_graph()
# set all nodes
g.set_n_repr(th.zeros((10, D)))
assert _pfc(g.get_n_repr()) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr()) == [0.] * 10
assert len(g.get_n_repr()) == 0
g.set_n_repr(th.zeros((10, D)))
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr(th.ones((3, D)), u)
assert _pfc(g.get_n_repr()) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = th.tensor([1, 2, 3])
assert _pfc(g.get_n_repr(u)) == [1., 0., 1.]
'''
s, d, eid
0, 1, 0
1, 9, 1
0, 2, 2
2, 9, 3
0, 3, 4
3, 9, 5
0, 4, 6
4, 9, 7
0, 5, 8
5, 9, 9
0, 6, 10
6, 9, 11
0, 7, 12
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
# set all edges
g.set_e_repr(th.zeros((17, D)))
assert _pfc(g.get_e_repr()) == [0.] * 17
# pop edges
assert _pfc(g.pop_e_repr()) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr(th.zeros((17, D)))
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr(th.ones((5, D)), u, v)
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr(th.ones((3, D)), u, v)
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr(th.ones((3, D)), u, v)
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr(hh, v)
h2 = g.get_n_repr()
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(hu, edge):
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.send(u, v)
def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.send(u, v)
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
# send_and_recv
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.send_and_recv(u, v)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
# pull
v = th.tensor([1, 2, 3, 9])
reduce_msg_shapes.clear()
g.pull(v)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# push
v = th.tensor([0, 1, 2, 3])
reduce_msg_shapes.clear()
g.push(v)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
reduce_msg_shapes.clear()
# update_all
reduce_msg_shapes.clear()
g.update_all()
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_routines()
...@@ -18,8 +18,8 @@ def tree1(): ...@@ -18,8 +18,8 @@ def tree1():
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(1, 0) g.add_edge(1, 0)
g.add_edge(2, 0) g.add_edge(2, 0)
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4])) g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr(th.randn(4, 10)) g.set_e_repr({'h' : th.randn(4, 10)})
return g return g
def tree2(): def tree2():
...@@ -37,17 +37,17 @@ def tree2(): ...@@ -37,17 +37,17 @@ def tree2():
g.add_edge(0, 4) g.add_edge(0, 4)
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(3, 1) g.add_edge(3, 1)
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4])) g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr(th.randn(4, 10)) g.set_e_repr({'h' : th.randn(4, 10)})
return g return g
def test_batch_unbatch(): def test_batch_unbatch():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
n1 = t1.get_n_repr() n1 = t1.get_n_repr()['h']
n2 = t2.get_n_repr() n2 = t2.get_n_repr()['h']
e1 = t1.get_e_repr() e1 = t1.get_e_repr()['h']
e2 = t2.get_e_repr() e2 = t2.get_e_repr()['h']
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
assert bg.number_of_nodes() == 10 assert bg.number_of_nodes() == 10
...@@ -57,10 +57,10 @@ def test_batch_unbatch(): ...@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4] assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg) tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr(), tt1.get_n_repr()) assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr(), tt1.get_e_repr()) assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr(), tt2.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), tt2.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])
def test_batch_unbatch1(): def test_batch_unbatch1():
t1 = tree1() t1 = tree1()
...@@ -74,20 +74,20 @@ def test_batch_unbatch1(): ...@@ -74,20 +74,20 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4] assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2) s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr(), s1.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), s1.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
assert th.allclose(t1.get_n_repr(), s2.get_n_repr()) assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr(), s2.get_e_repr()) assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr(), s3.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), s3.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])
def test_batch_sendrecv(): def test_batch_sendrecv():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src) bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1)) bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
u = [3, 4, 2 + 5, 0 + 5] u = [3, 4, 2 + 5, 0 + 5]
v = [1, 1, 4 + 5, 4 + 5] v = [1, 1, 4 + 5, 4 + 5]
...@@ -95,8 +95,8 @@ def test_batch_sendrecv(): ...@@ -95,8 +95,8 @@ def test_batch_sendrecv():
bg.recv(v) bg.recv(v)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()[1] == 7 assert t1.get_n_repr()['h'][1] == 7
assert t2.get_n_repr()[4] == 2 assert t2.get_n_repr()['h'][4] == 2
def test_batch_propagate(): def test_batch_propagate():
...@@ -104,8 +104,8 @@ def test_batch_propagate(): ...@@ -104,8 +104,8 @@ def test_batch_propagate():
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src) bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1)) bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
# get leaves. # get leaves.
order = [] order = []
...@@ -123,45 +123,37 @@ def test_batch_propagate(): ...@@ -123,45 +123,37 @@ def test_batch_propagate():
bg.propagate(traverser=order) bg.propagate(traverser=order)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()[0] == 9 assert t1.get_n_repr()['h'][0] == 9
assert t2.get_n_repr()[1] == 5 assert t2.get_n_repr()['h'][1] == 5
def test_batched_edge_ordering(): def test_batched_edge_ordering():
g1 = dgl.DGLGraph() g1 = dgl.DGLGraph()
g1.add_nodes(6) g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1]) g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10) e1 = th.randn(5, 10)
g1.set_e_repr(e1) g1.set_e_repr({'h' : e1})
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g2.add_nodes(6) g2.add_nodes(6)
g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0]) g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
e2 = th.randn(6, 10) e2 = th.randn(6, 10)
g2.set_e_repr(e2) g2.set_e_repr({'h' : e2})
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
r1 = g.get_e_repr()[g.edge_id(4, 5)] r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
r2 = g1.get_e_repr()[g1.edge_id(4, 5)] r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
assert th.equal(r1, r2) assert th.equal(r1, r2)
def test_batch_no_edge(): def test_batch_no_edge():
# FIXME: current impl cannot handle this case!!!
# comment out for now to test CI
return
"""
g1 = dgl.DGLGraph() g1 = dgl.DGLGraph()
g1.add_nodes(6) g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1]) g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10) e1 = th.randn(5, 10)
g1.set_e_repr(e1)
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g2.add_nodes(6) g2.add_nodes(6)
g2.add_edges([0, 1, 2, 5, 4, 5], [1 ,2 ,3, 4, 3, 0]) g2.add_edges([0, 1, 2, 5, 4, 5], [1 ,2 ,3, 4, 3, 0])
e2 = th.randn(6, 10) e2 = th.randn(6, 10)
g2.set_e_repr(e2)
g3 = dgl.DGLGraph() g3 = dgl.DGLGraph()
g3.add_nodes(1) # no edges g3.add_nodes(1) # no edges
g = dgl.batch([g1, g3, g2]) # should not throw an error g = dgl.batch([g1, g3, g2]) # should not throw an error
"""
if __name__ == '__main__': if __name__ == '__main__':
test_batch_unbatch() test_batch_unbatch()
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
#!/bin/sh
# cpplint
echo 'Checking code style of C++ codes...'
python3 third_party/dmlc-core/scripts/lint.py dgl cpp include src
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment