Unverified commit 11e42d10, authored by Minjie Wang and committed by GitHub

Batching semantics and naive frame storage (#31)

* batch message_func, reduce_func and update_func

Conflicts:
	python/dgl/backend/pytorch.py

* test cases for batching

Conflicts:
	python/dgl/graph.py

* resolve conflicts

* setter/getter

Conflicts:
	python/dgl/graph.py

* test setter/getter

Conflicts:
	python/dgl/graph.py

* merge DGLGraph and DGLBGraph

Conflicts:
	python/dgl/graph.py

* batchability test

Conflicts:
	python/dgl/graph.py

* New interface (draft)

Conflicts:
	_reference/gat_mx.py
	_reference/molecular-gcn.py
	_reference/molecular-gcn_mx.py
	_reference/multi-gcn.py
	_reference/multi-gcn_mx.py
	_reference/mx.py
	python/dgl/graph.py

* Batch operations on graph

Conflicts:
	.gitignore
	python/dgl/backend/__init__.py
	python/dgl/backend/numpy.py
	python/dgl/graph.py

* sendto

* storage

* NodeDict

* DGLFrame/DGLArray

* scaffold code for graph.py

* clean up files; initial frame code

* basic frame tests using pytorch

* frame autograd test passed

* fix non-batched tests

* initial code for cached graph; tested

* batch sendto

* batch recv

* update routines

* update all

* anonymous repr batching

* specialize test

* igraph dep

* fix

* fix

* fix

* fix

* fix

* clean some files

* batch setter and getter

* fix utests
parent 8361bbbe
import torch as th
from dgl.graph import DGLGraph
D = 5
reduce_msg_shapes = set()
def message_func(src, edge):
    assert len(src['h'].shape) == 2
    assert src['h'].shape[1] == D
    return {'m' : src['h']}

def reduce_func(node, msgs):
    msgs = msgs['m']
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return th.sum(msgs, 1)

def update_func(node, accum):
    assert node['h'].shape == accum.shape
    return {'h' : node['h'] + accum}
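# Illustrative sketch (not part of the original test, and not DGL's actual
# scheduler): it shows the batched semantics the UDFs above assume. The shape
# assertions suggest messages are grouped by in-degree; for one group of
# destination nodes that all have in-degree `deg`, messages can be stacked into
# a (num_nodes, deg, D) tensor, reduced along dim 1, and fed to the update
# function. The helper name `_apply_one_bucket` and its arguments are
# hypothetical.
def _apply_one_bucket(src_repr, edge_repr, node_repr, deg):
    # src_repr['h']: (num_nodes * deg, D) source features, one row per in-edge
    msgs = message_func(src_repr, edge_repr)        # {'m': (num_nodes * deg, D)}
    stacked = {'m' : msgs['m'].view(-1, deg, D)}    # (num_nodes, deg, D)
    accum = reduce_func(node_repr, stacked)         # (num_nodes, D)
    return update_func(node_repr, accum)            # {'h': (num_nodes, D)}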
def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    col = th.randn(10, D)
    g.set_n_repr({'h' : col})
    return g

def test_batch_setter_getter():
    def _pfc(x):
        return list(x.numpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr({'h' : th.zeros((10, D))})
    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr({'h' : th.ones((3, D))}, u)
    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
    '''
    s, d, eid
    0, 1, 0
    1, 9, 1
    0, 2, 2
    2, 9, 3
    0, 3, 4
    3, 9, 5
    0, 4, 6
    4, 9, 7
    0, 5, 8
    5, 9, 9
    0, 6, 10
    6, 9, 11
    0, 7, 12
    7, 9, 13
    0, 8, 14
    8, 9, 15
    9, 0, 16
    '''
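    # (Added note, not in the original test: with this insertion order the edge
    # ids follow a simple pattern -- edge (0, i) gets id 2*(i-1), edge (i, 9)
    # gets id 2*i-1, and the back edge (9, 0) gets id 16.)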
    # set all edges
    g.set_e_repr({'l' : th.zeros((17, D))})
    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
    # set partial edges (many-many)
    # TODO(minjie): following case will fail at the moment as CachedGraph
    # does not maintain edge addition order.
    #u = th.tensor([0, 0, 2, 5, 9])
    #v = th.tensor([1, 3, 9, 9, 0])
    #g.set_e_repr({'l' : th.ones((5, D))}, u, v)
    #truth = [0.] * 17
    #truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    #assert _pfc(g.get_e_repr()['l']) == truth
def test_batch_send():
    g = generate_graph()
    def _fmsg(src, edge):
        assert src['h'].shape == (5, D)
        return {'m' : src['h']}
    g.register_message_func(_fmsg, batchable=True)
    # many-many sendto
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.sendto(u, v)
    # one-many sendto
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.sendto(u, v)
    # many-one sendto
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.sendto(u, v)

def test_batch_recv():
    g = generate_graph()
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.register_update_func(update_func, batchable=True)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
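    # (Added note, not in the original test: after sendto, nodes 1, 2 and 3
    # each hold one message while node 9 holds three, so a degree-bucketed
    # reduce is expected to see one (3, 1, D) batch and one (1, 3, D) batch,
    # which is what the assertion below checks.)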
    reduce_msg_shapes.clear()
    g.sendto(u, v)
    g.recv(th.unique(v))
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

def test_update_routines():
    g = generate_graph()
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.register_update_func(update_func, batchable=True)

    # update_by_edge
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.update_by_edge(u, v)
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # update_to
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.update_to(v)
    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # update_from
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.update_from(v)
    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
    reduce_msg_shapes.clear()

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()

if __name__ == '__main__':
    test_batch_setter_getter()
    test_batch_send()
    test_batch_recv()
    test_update_routines()
import torch as th
from dgl.graph import DGLGraph, __REPR__
D = 32
reduce_msg_shapes = set()
def message_func(hu, e_uv):
    assert len(hu.shape) == 2
    assert hu.shape[1] == D
    return hu

def reduce_func(hv, msgs):
    reduce_msg_shapes.add(tuple(msgs.shape))
    assert len(msgs.shape) == 3
    assert msgs.shape[2] == D
    return th.sum(msgs, 1)

def update_func(hv, accum):
    assert hv.shape == accum.shape
    return hv + accum
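# (Added note, not in the original file: this is the anonymous-representation
# variant of the previous tests -- node and edge features are single tensors
# stored under the special __REPR__ key instead of name -> tensor dicts, so
# the UDFs above take and return plain tensors rather than dicts.)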
def generate_graph():
    g = DGLGraph()
    for i in range(10):
        g.add_node(i) # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    # add a back flow from 9 to 0
    g.add_edge(9, 0)
    col = th.randn(10, D)
    g.set_n_repr(col)
    return g

def test_batch_setter_getter():
    def _pfc(x):
        return list(x.numpy()[:,0])
    g = generate_graph()
    # set all nodes
    g.set_n_repr(th.zeros((10, D)))
    assert _pfc(g.get_n_repr()) == [0.] * 10
    # set partial nodes
    u = th.tensor([1, 3, 5])
    g.set_n_repr(th.ones((3, D)), u)
    assert _pfc(g.get_n_repr()) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
    # get partial nodes
    u = th.tensor([1, 2, 3])
    assert _pfc(g.get_n_repr(u)) == [1., 0., 1.]
    '''
    s, d, eid
    0, 1, 0
    1, 9, 1
    0, 2, 2
    2, 9, 3
    0, 3, 4
    3, 9, 5
    0, 4, 6
    4, 9, 7
    0, 5, 8
    5, 9, 9
    0, 6, 10
    6, 9, 11
    0, 7, 12
    7, 9, 13
    0, 8, 14
    8, 9, 15
    9, 0, 16
    '''
    # set all edges
    g.set_e_repr(th.zeros((17, D)))
    assert _pfc(g.get_e_repr()) == [0.] * 17
    # set partial edges (many-many)
    # TODO(minjie): following case will fail at the moment as CachedGraph
    # does not maintain edge addition order.
    #u = th.tensor([0, 0, 2, 5, 9])
    #v = th.tensor([1, 3, 9, 9, 0])
    #g.set_e_repr({'l' : th.ones((5, D))}, u, v)
    #truth = [0.] * 17
    #truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
    #assert _pfc(g.get_e_repr()['l']) == truth
def test_batch_send():
    g = generate_graph()
    def _fmsg(hu, edge):
        assert hu.shape == (5, D)
        return hu
    g.register_message_func(_fmsg, batchable=True)
    # many-many sendto
    u = th.tensor([0, 0, 0, 0, 0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.sendto(u, v)
    # one-many sendto
    u = th.tensor([0])
    v = th.tensor([1, 2, 3, 4, 5])
    g.sendto(u, v)
    # many-one sendto
    u = th.tensor([1, 2, 3, 4, 5])
    v = th.tensor([9])
    g.sendto(u, v)

def test_batch_recv():
    g = generate_graph()
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.register_update_func(update_func, batchable=True)
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    reduce_msg_shapes.clear()
    g.sendto(u, v)
    g.recv(th.unique(v))
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

def test_update_routines():
    g = generate_graph()
    g.register_message_func(message_func, batchable=True)
    g.register_reduce_func(reduce_func, batchable=True)
    g.register_update_func(update_func, batchable=True)

    # update_by_edge
    reduce_msg_shapes.clear()
    u = th.tensor([0, 0, 0, 4, 5, 6])
    v = th.tensor([1, 2, 3, 9, 9, 9])
    g.update_by_edge(u, v)
    assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # update_to
    v = th.tensor([1, 2, 3, 9])
    reduce_msg_shapes.clear()
    g.update_to(v)
    assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
    reduce_msg_shapes.clear()

    # update_from
    v = th.tensor([0, 1, 2, 3])
    reduce_msg_shapes.clear()
    g.update_from(v)
    assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
    reduce_msg_shapes.clear()

    # update_all
    reduce_msg_shapes.clear()
    g.update_all()
    assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
    reduce_msg_shapes.clear()

if __name__ == '__main__':
    test_batch_send()
    test_batch_recv()
    test_update_routines()
import torch as th
import numpy as np
import networkx as nx
from dgl import DGLGraph
from dgl.cached_graph import *
def check_eq(a, b):
    assert a.shape == b.shape
    assert th.sum(a == b) == int(np.prod(list(a.shape)))
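# Illustrative sketch (not in the original test, and not the actual
# dgl.cached_graph implementation): a minimal structure with the behaviour
# test_basics below relies on -- consecutive edge ids assigned in insertion
# order plus in/out edge queries. The class name ToyEdgeStore is hypothetical.
class ToyEdgeStore(object):
    def __init__(self):
        self._eid = {}   # (u, v) -> edge id, assigned in insertion order
        self._in = {}    # v -> list of (u, v)
        self._out = {}   # u -> list of (u, v)

    def add_edge(self, u, v):
        self._eid[(u, v)] = len(self._eid)
        self._in.setdefault(v, []).append((u, v))
        self._out.setdefault(u, []).append((u, v))

    def get_edge_id(self, u, v):
        return self._eid[(u, v)]

    def in_edges(self, nodes):
        # nodes: iterable of node ids; returns (src list, dst list)
        pairs = [e for v in nodes for e in self._in.get(v, [])]
        return [u for u, _ in pairs], [v for _, v in pairs]

    def out_edges(self, nodes):
        pairs = [e for u in nodes for e in self._out.get(u, [])]
        return [u for u, _ in pairs], [v for _, v in pairs]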
def test_basics():
    g = DGLGraph()
    g.add_edge(0, 1)
    g.add_edge(1, 2)
    g.add_edge(1, 3)
    g.add_edge(2, 4)
    g.add_edge(2, 5)
    cg = create_cached_graph(g)
    u = th.tensor([0, 1, 1, 2, 2])
    v = th.tensor([1, 2, 3, 4, 5])
    check_eq(cg.get_edge_id(u, v), th.tensor([0, 1, 2, 3, 4]))
    cg.add_edges(0, 2)
    assert cg.get_edge_id(0, 2) == 5
    query = th.tensor([1, 2])
    s, d = cg.in_edges(query)
    check_eq(s, th.tensor([0, 0, 1]))
    check_eq(d, th.tensor([1, 2, 2]))
    s, d = cg.out_edges(query)
    check_eq(s, th.tensor([1, 1, 2, 2]))
    check_eq(d, th.tensor([2, 3, 4, 5]))

if __name__ == '__main__':
    test_basics()
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame
N = 10
D = 32
def check_eq(a, b):
    assert a.shape == b.shape
    assert th.sum(a == b) == int(np.prod(list(a.shape)))

def create_test_data(grad=False):
    c1 = Variable(th.randn(N, D), requires_grad=grad)
    c2 = Variable(th.randn(N, D), requires_grad=grad)
    c3 = Variable(th.randn(N, D), requires_grad=grad)
    return {'a1' : c1, 'a2' : c2, 'a3' : c3}
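# Illustrative sketch (not in the original test, and not the actual
# dgl.frame.Frame): the column-store behaviour the tests below exercise is
# roughly a dict of name -> (num_rows, D) tensors with whole-column and
# row-wise getters/setters. The class name ToyFrame is hypothetical; its row
# setter copies the column before writing rows into it, so autograd can still
# track both the old column and the new values, which is what
# test_row_getter_setter_grad checks for.
class ToyFrame(object):
    def __init__(self, data=None):
        self._columns = dict(data) if data else {}

    def __getitem__(self, key):
        if isinstance(key, str):
            return self._columns[key]                          # whole column
        return {k : v[key] for k, v in self._columns.items()}  # rows by index tensor

    def __setitem__(self, key, value):
        if isinstance(key, str):
            self._columns[key] = value                         # replace a column
            return
        for k, v in value.items():                             # write rows out-of-place
            col = self._columns[k].clone()
            col[key] = v
            self._columns[k] = col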
def test_create():
    data = create_test_data()
    f1 = Frame()
    for k, v in data.items():
        f1.add_column(k, v)
    assert f1.schemes == set(data.keys())
    assert f1.num_columns == 3
    assert f1.num_rows == N
    f2 = Frame(data)
    assert f2.schemes == set(data.keys())
    assert f2.num_columns == 3
    assert f2.num_rows == N
    f1.clear()
    assert len(f1.schemes) == 0
    assert f1.num_rows == 0
def test_col_getter_setter():
    data = create_test_data()
    f = Frame(data)
    check_eq(f['a1'], data['a1'])
    f['a1'] = data['a2']
    check_eq(f['a1'], data['a2'])
def test_row_getter_setter():
    data = create_test_data()
    f = Frame(data)
    # getter
    # test non-duplicate keys
    rowid = th.tensor([0, 2])
    rows = f[rowid]
    for k, v in rows:
        assert v.shape == (len(rowid), D)
        check_eq(v, data[k][rowid])
    # test duplicate keys
    rowid = th.tensor([8, 2, 2, 1])
    rows = f[rowid]
    for k, v in rows:
        assert v.shape == (len(rowid), D)
        check_eq(v, data[k][rowid])
    # setter
    rowid = th.tensor([0, 2, 4])
    vals = {'a1' : th.zeros((len(rowid), D)),
            'a2' : th.zeros((len(rowid), D)),
            'a3' : th.zeros((len(rowid), D)),
           }
    f[rowid] = vals
    for k, v in f[rowid]:
        check_eq(v, th.zeros((len(rowid), D)))
def test_row_getter_setter_grad():
    data = create_test_data(grad=True)
    f = Frame(data)
    # getter
    c1 = f['a1']
    # test non-duplicate keys
    rowid = th.tensor([0, 2])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
    c1.grad.data.zero_()
    # test duplicate keys; row 2 is selected twice, so its gradient
    # accumulates to 2 while rows selected once get 1.
    rowid = th.tensor([8, 2, 2, 1])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
    c1.grad.data.zero_()
    # setter
    c1 = f['a1']
    rowid = th.tensor([0, 2, 4])
    vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
           }
    f[rowid] = vals
    c11 = f['a1']
    c11.backward(th.ones((N, D)))
    # rows 0, 2 and 4 were overwritten, so the old column gets no gradient there
    check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
    check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
    assert vals['a2'].grad is None
def test_append():
    data = create_test_data()
    f1 = Frame()
    f2 = Frame(data)
    f1.append(data)
    assert f1.num_rows == N
    f1.append(f2)
    assert f1.num_rows == 2 * N
    c1 = f1['a1']
    assert c1.shape == (2 * N, D)
    truth = th.cat([data['a1'], data['a1']])
    check_eq(truth, c1)

if __name__ == '__main__':
    test_create()
    test_col_getter_setter()
    test_append()
    test_row_getter_setter()
    test_row_getter_setter_grad()
 from dgl import DGLGraph
 from dgl.graph import __REPR__
 
-def message_func(hu, hv, e_uv):
+def message_func(hu, e_uv):
     return hu + e_uv
 
 def update_func(h, accum):
@@ -10,20 +10,17 @@ def update_func(h, accum):
 def generate_graph():
     g = DGLGraph()
     for i in range(10):
-        g.add_node(i) # 10 nodes.
-        g.set_n_repr(i, i+1)
+        g.add_node(i, __REPR__=i+1) # 10 nodes.
     # create a graph where 0 is the source and 9 is the sink
     for i in range(1, 9):
-        g.add_edge(0, i)
-        g.set_e_repr(0, i, 1)
-        g.add_edge(i, 9)
-        g.set_e_repr(i, 9, 1)
+        g.add_edge(0, i, __REPR__=1)
+        g.add_edge(i, 9, __REPR__=1)
     # add a back flow from 9 to 0
     g.add_edge(9, 0)
     return g
 
 def check(g, h):
-    nh = [str(g.get_n_repr(i)) for i in range(10)]
+    nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
     h = [str(x) for x in h]
     assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
@@ -41,7 +38,7 @@ def test_sendrecv():
     g.recv(9)
     check(g, [1, 4, 3, 4, 5, 6, 7, 8, 9, 25])
 
-def message_func_hybrid(src, dst, edge):
+def message_func_hybrid(src, edge):
     return src[__REPR__] + edge
 
 def update_func_hybrid(node, accum):
...
 from dgl.graph import DGLGraph
 
-def message_func(src, dst, edge):
+def message_func(src, edge):
     return src['h']
 
 def update_func(node, accum):
...
 from dgl import DGLGraph
 from dgl.graph import __REPR__
 
-def message_func(hu, hv, e_uv):
+def message_func(hu, e_uv):
     return hu
 
-def message_not_called(hu, hv, e_uv):
+def message_not_called(hu, e_uv):
     assert False
     return hu
@@ -21,21 +21,18 @@ def update_func(h, accum):
     return h + accum
 
 def check(g, h):
-    nh = [str(g.get_n_repr(i)) for i in range(10)]
+    nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
     h = [str(x) for x in h]
     assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))
 
 def generate_graph():
     g = DGLGraph()
     for i in range(10):
-        g.add_node(i) # 10 nodes.
-        g.set_n_repr(i, i+1)
+        g.add_node(i, __REPR__=i+1) # 10 nodes.
     # create a graph where 0 is the source and 9 is the sink
     for i in range(1, 9):
         g.add_edge(0, i)
-        g.set_e_repr(0, i, 1)
         g.add_edge(i, 9)
-        g.set_e_repr(i, 9, 1)
     return g
 
 def test_no_msg_update():
...
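Taken together, the hunks above apply one mechanical change to the older, non-batched tests: per-edge user functions no longer receive the destination representation, and scalar node/edge representations are attached via keyword arguments at add time instead of through separate setter calls. An illustrative summary (not copied verbatim from any single file):

# old interface (removed in this commit)
#   def message_func(hu, hv, e_uv): return hu + e_uv
#   g.add_node(i); g.set_n_repr(i, i + 1)
#
# new interface (added in this commit)
#   def message_func(hu, e_uv): return hu + e_uv
#   g.add_node(i, __REPR__=i + 1)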