Unverified commit c42eac71, authored by Lingfan Yu, committed by GitHub

Fix batched graph edge order bug and other fixes (#50)

* fix dgl.batch edge ordering bug

* add graph batching test cases

* fix device context handling in partial spmv

* add dataset generation for DGMG
parent 6105e441
@@ -6,7 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 import argparse
-from util import DataLoader, elapsed
+from util import DataLoader, elapsed, generate_dataset
 import time

 class MLP(nn.Module):
@@ -246,9 +246,16 @@ if __name__ == '__main__':
                         help="number of hidden gcn layers")
     parser.add_argument("--dataset", type=str, default='samples.p',
                         help="dataset pickle file")
+    parser.add_argument("--gen-dataset", type=str, default=None,
+                        help="parameters to generate B-A graph datasets. Format: <#node>,<#edge>,<#sample>")
     parser.add_argument("--batch-size", type=int, default=32,
                         help="batch size")
     args = parser.parse_args()
     print(args)
+
+    # generate dataset if needed
+    if args.gen_dataset is not None:
+        n_node, n_edge, n_sample = map(int, args.gen_dataset.split(','))
+        generate_dataset(n_node, n_edge, n_sample, args.dataset)

     main(args)
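For context, here is a minimal, standalone sketch of the new command-line path (the script and flag wiring below are illustrative, not the exact DGMG example): the `--gen-dataset` value is a comma-separated `<#node>,<#edge>,<#sample>` triple that gets split into Barabási–Albert parameters and forwarded to `generate_dataset`, which writes the pickle that `--dataset` then points at.

```python
# Illustrative sketch (assumed standalone script, not the DGMG example itself):
# parse a "<#node>,<#edge>,<#sample>" string the way the new --gen-dataset flag does.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default='samples.p',
                    help="dataset pickle file")
parser.add_argument("--gen-dataset", type=str, default=None,
                    help="Format: <#node>,<#edge>,<#sample>")
args = parser.parse_args(["--gen-dataset", "15,2,1024"])  # e.g. 15 nodes, 2 edges per new node, 1024 graphs

if args.gen_dataset is not None:
    n_node, n_edge, n_sample = map(int, args.gen_dataset.split(','))
    print(n_node, n_edge, n_sample, "->", args.dataset)   # 15 2 1024 -> samples.p
```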
@@ -15,16 +15,13 @@ def convert_graph_to_ordering(g):
         ordering.append((m, n))
     return ordering

-def generate_dataset():
-    n = 15
-    m = 2
-    n_samples = 1024
+def generate_dataset(n, m, n_samples, fname):
     samples = []
     for _ in range(n_samples):
         g = nx.barabasi_albert_graph(n, m)
         samples.append(convert_graph_to_ordering(g))
-    with open('samples.p', 'wb') as f:
+    with open(fname, 'wb') as f:
         pickle.dump(samples, f)


 class DataLoader(object):
@@ -153,4 +150,8 @@ def elapsed(msg, start, end):
     print("{}: {} ms".format(msg, int((end-start)*1000)))

 if __name__ == '__main__':
-    generate_dataset()
+    n = 15
+    m = 2
+    n_samples = 1024
+    fname = 'samples.p'
+    generate_dataset(n, m, n_samples, fname)
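With the parameters now passed explicitly, a caller can generate a dataset and read it back in a few lines. A minimal sketch, assuming `util.py` from this example is importable and that each stored sample is the list of (node, node) pairs produced by `convert_graph_to_ordering`:

```python
# Sketch of round-tripping the pickle written by generate_dataset.
# Assumes util.py (from the DGMG example) is on the import path; file name is arbitrary.
import pickle
from util import generate_dataset

generate_dataset(n=15, m=2, n_samples=4, fname='tiny_samples.p')

with open('tiny_samples.p', 'rb') as f:
    samples = pickle.load(f)

print(len(samples))    # 4 sampled Barabasi-Albert graphs
print(samples[0][:3])  # first few (node, node) decision pairs of one ordering
```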
@@ -22,7 +22,7 @@ class BatchedDGLGraph(DGLGraph):
         self.add_nodes_from(range(self.node_offset[-1]))

         # in-order add relabeled edges
-        self.new_edge_list = [np.array(g.edges) + offset
+        self.new_edge_list = [np.array(g.edge_list) + offset
                               for g, offset in zip(self.graph_list, self.node_offset[:-1])]
         self.new_edges = np.concatenate(self.new_edge_list)
         self.add_edges_from(self.new_edges)
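This one-line change is the core of the edge-ordering fix: `DGLGraph` (a `DiGraph` subclass at this point) keeps edge features aligned with its insertion-ordered `edge_list`, while NetworkX's `g.edges` reports edges grouped by source-node adjacency, so rebuilding the batched graph from `g.edges` silently permutes the edge features. A minimal sketch with plain NetworkX (no DGL) of the order mismatch:

```python
# Minimal sketch of why relying on g.edges breaks edge-feature alignment:
# a DiGraph reports edges grouped by source node, not in the order they were
# added, so features indexed by insertion order get scrambled.
import networkx as nx

g = nx.DiGraph()
g.add_nodes_from(range(6))
insertion_order = [(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)]
g.add_edges_from(insertion_order)

print(list(g.edges))    # [(0, 1), (2, 3), (2, 1), (4, 5), (4, 3)] -- adjacency order
print(insertion_order)  # [(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)] -- order the features follow
assert list(g.edges) != insertion_order
```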
@@ -710,11 +710,12 @@ class DGLGraph(DiGraph):
         m = len(new2old)
         # TODO(minjie): context
         adjmat = F.sparse_tensor(idx, dat, [m, n])
+        ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
         # TODO(minjie): use lazy dict for reduced_msgs
         reduced_msgs = {}
         for key in self._node_frame.schemes:
             col = self._node_frame[key]
-            reduced_msgs[key] = F.spmm(adjmat, col)
+            reduced_msgs[key] = F.spmm(ctx_adjmat.get(F.get_context(col)), col)
         if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
             reduced_msgs = reduced_msgs[__REPR__]
         node_repr = self.get_n_repr(new2old)
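The partial-spmv fix routes the sparse adjacency matrix through a per-device cache so that `F.spmm` always multiplies operands living on the same device as the node-feature column. `utils.CtxCachedObject` and the `F.*` backend shims are DGL internals not shown here; the following is only a generic PyTorch analogue of the same caching pattern, with every name invented for illustration:

```python
# Generic PyTorch analogue (NOT DGL's utils.CtxCachedObject) of the context-cache
# pattern above: build the sparse adjacency once, then lazily copy and cache it
# per device so the sparse-dense matmul sees operands on the same device.
import torch

class CtxCachedTensor:
    def __init__(self, build_fn):
        self._build_fn = build_fn   # maps a device to the tensor on that device
        self._cache = {}            # device -> cached tensor

    def get(self, device):
        if device not in self._cache:
            self._cache[device] = self._build_fn(device)
        return self._cache[device]

# Toy usage: a 2x3 sparse "adjacency" and a dense feature column.
idx = torch.tensor([[0, 1], [1, 2]])                      # two nonzeros: (0,1) and (1,2)
adj = torch.sparse_coo_tensor(idx, torch.ones(2), (2, 3))
cached_adj = CtxCachedTensor(lambda dev: adj.to(dev))

col = torch.randn(3, 4)                                   # pretend node features
out = torch.sparse.mm(cached_adj.get(col.device), col)
print(out.shape)                                          # torch.Size([2, 4])
```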
@@ -23,6 +23,7 @@ def tree1():
     g.add_edge(1, 0)
     g.add_edge(2, 0)
     g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
+    g.set_e_repr(torch.randn(4, 10))
     return g

 def tree2():
@@ -45,19 +46,24 @@ def tree2():
     g.add_edge(4, 1)
     g.add_edge(3, 1)
     g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
+    g.set_e_repr(torch.randn(4, 10))
     return g

 def test_batch_unbatch():
     t1 = tree1()
     t2 = tree2()
-    f1 = t1.get_n_repr()
-    f2 = t2.get_n_repr()
+    n1 = t1.get_n_repr()
+    n2 = t2.get_n_repr()
+    e1 = t1.get_e_repr()
+    e2 = t2.get_e_repr()
     bg = dgl.batch([t1, t2])
     dgl.unbatch(bg)
-    assert(f1.equal(t1.get_n_repr()))
-    assert(f2.equal(t2.get_n_repr()))
+    assert(n1.equal(t1.get_n_repr()))
+    assert(n2.equal(t2.get_n_repr()))
+    assert(e1.equal(t1.get_e_repr()))
+    assert(e2.equal(t2.get_e_repr()))

 def test_batch_sendrecv():
@@ -120,8 +126,25 @@ def test_batch_propagate():
     assert t1.get_n_repr()[0] == 9
     assert t2.get_n_repr()[1] == 5

+def test_batched_edge_ordering():
+    g1 = dgl.DGLGraph()
+    g1.add_nodes_from([0, 1, 2, 3, 4, 5])
+    g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
+    e1 = torch.randn(5, 10)
+    g1.set_e_repr(e1)
+    g2 = dgl.DGLGraph()
+    g2.add_nodes_from([0, 1, 2, 3, 4, 5])
+    g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
+    e2 = torch.randn(6, 10)
+    g2.set_e_repr(e2)
+
+    # the edge feature looked up by edge id must survive batching unchanged
+    g = dgl.batch([g1, g2])
+    r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
+    r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
+    assert torch.equal(r1, r2)
+
 if __name__ == '__main__':
     test_batch_unbatch()
+    test_batched_edge_ordering()
     test_batch_sendrecv()
     test_batch_propagate()