"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "d163c2c16629e3da41dc5d7b8dbb76b8fc6249db"
Commit 596ca471 authored by GaiYu0's avatar GaiYu0
Browse files

Merge branch 'cpp' of https://github.com/jermainewang/dgl into line-graph

parents 52ed6a45 72f63455
...@@ -5,13 +5,9 @@ from dgl.graph import DGLGraph ...@@ -5,13 +5,9 @@ from dgl.graph import DGLGraph
D = 5 D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
for i in range(10): g.add_nodes(10)
g.add_node(i) # 10 nodes.
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
...@@ -29,17 +25,19 @@ def test_basics(): ...@@ -29,17 +25,19 @@ def test_basics():
h = g.get_n_repr()['h'] h = g.get_n_repr()['h']
l = g.get_e_repr()['l'] l = g.get_e_repr()['l']
nid = [0, 2, 3, 6, 7, 9] nid = [0, 2, 3, 6, 7, 9]
eid = [2, 3, 4, 5, 10, 11, 12, 13, 16]
sg = g.subgraph(nid) sg = g.subgraph(nid)
eid = {2, 3, 4, 5, 10, 11, 12, 13, 16}
assert set(sg.parent_eid.numpy()) == eid
eid = sg.parent_eid
# the subgraph is empty initially # the subgraph is empty initially
assert len(sg.get_n_repr()) == 0 assert len(sg.get_n_repr()) == 0
assert len(sg.get_e_repr()) == 0 assert len(sg.get_e_repr()) == 0
# the data is copied after explict copy from # the data is copied after explict copy from
sg.copy_from(g) sg.copy_from_parent()
assert len(sg.get_n_repr()) == 1 assert len(sg.get_n_repr()) == 1
assert len(sg.get_e_repr()) == 1 assert len(sg.get_e_repr()) == 1
sh = sg.get_n_repr()['h'] sh = sg.get_n_repr()['h']
assert check_eq(h[nid], sh) assert th.allclose(h[nid], sh)
''' '''
s, d, eid s, d, eid
0, 1, 0 0, 1, 0
...@@ -60,11 +58,11 @@ def test_basics(): ...@@ -60,11 +58,11 @@ def test_basics():
8, 9, 15 3 8, 9, 15 3
9, 0, 16 1 9, 0, 16 1
''' '''
assert check_eq(l[eid], sg.get_e_repr()['l']) assert th.allclose(l[eid], sg.get_e_repr()['l'])
# update the node/edge features on the subgraph should NOT # update the node/edge features on the subgraph should NOT
# reflect to the parent graph. # reflect to the parent graph.
sg.set_n_repr({'h' : th.zeros((6, D))}) sg.set_n_repr({'h' : th.zeros((6, D))})
assert check_eq(h, g.get_n_repr()['h']) assert th.allclose(h, g.get_n_repr()['h'])
def test_merge(): def test_merge():
g = generate_graph() g = generate_graph()
...@@ -85,10 +83,10 @@ def test_merge(): ...@@ -85,10 +83,10 @@ def test_merge():
h = g.get_n_repr()['h'][:,0] h = g.get_n_repr()['h'][:,0]
l = g.get_e_repr()['l'][:,0] l = g.get_e_repr()['l'][:,0]
assert check_eq(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.])) assert th.allclose(h, th.tensor([3., 0., 3., 3., 2., 0., 1., 1., 0., 1.]))
assert check_eq(l, assert th.allclose(l,
th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.])) th.tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 3., 1., 4., 1., 4., 0., 3., 1.]))
if __name__ == '__main__': if __name__ == '__main__':
test_basics() test_basics()
test_merge() #test_merge()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment