"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "952b9131a21b03691c5086b0f32f11d927664755"
Unverified Commit 16561a2e authored by Da Zheng, committed by GitHub

[Test] Add tests for TensorFlow (#1501)



* add test.

* move test code.

* remove unnecessary test.

* fix.

* turn on tests for TF.

* Revert "move test code."

This reverts commit e7b4f36395b2121a7be030bd4364a704d0e357bf.

* fix.

* fix.

* skip test for tensorflow.
Co-authored-by: Chao Ma <mctt90@gmail.com>
parent 6ae440db
@@ -140,9 +140,11 @@ def load_partition(conf_file, part_id):
    # TODO we need to fix this. DGL backend doesn't support boolean or byte.
    # int64 is unnecessary.
-   part_ids = F.zerocopy_from_numpy(node_map)[graph.ndata[NID]]
+   node_map = F.zerocopy_from_numpy(node_map)
+   part_ids = F.gather_row(node_map, graph.ndata[NID])
    graph.ndata['local_node'] = F.astype(part_ids == part_id, F.int64)
-   part_ids = F.zerocopy_from_numpy(edge_map)[graph.edata[EID]]
+   edge_map = F.zerocopy_from_numpy(edge_map)
+   part_ids = F.gather_row(edge_map, graph.edata[EID])
    graph.edata['local_edge'] = F.astype(part_ids == part_id, F.int64)
    return graph, node_feats, edge_feats, meta
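The switch to F.gather_row above is needed because NumPy-style fancy indexing, tensor[index_tensor], works on the PyTorch and MXNet backends but not on plain TensorFlow tensors; F.gather_row dispatches to each framework's own gather primitive instead. A minimal sketch of the equivalence, using toy arrays rather than the real partition data:

import numpy as np
from dgl import backend as F  # dispatches to the active DGLBACKEND

# node_map[i] = partition that owns node i; nids = node IDs to look up.
node_map = F.zerocopy_from_numpy(np.array([0, 0, 1, 1], dtype=np.int64))
nids = F.zerocopy_from_numpy(np.array([3, 1, 0], dtype=np.int64))

part_ids = F.gather_row(node_map, nids)  # backend-neutral node_map[nids]
print(F.asnumpy(part_ids))               # [1 0 0]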
@@ -252,9 +254,9 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
                             len(local_nodes), len(local_edges)))
            tot_num_inner_edges += len(local_edges)
            for name in g.ndata:
-               node_feats[name] = g.ndata[name][local_nodes]
+               node_feats[name] = F.gather_row(g.ndata[name], local_nodes)
            for name in g.edata:
-               edge_feats[name] = g.edata[name][local_edges]
+               edge_feats[name] = F.gather_row(g.edata[name], local_edges)
        else:
            for name in g.ndata:
                node_feats[name] = g.ndata[name]
...
@@ -47,7 +47,7 @@ def test_partition():
        assert name in node_feats
        assert node_feats[name].shape[0] == len(local_nodes)
        assert len(local_nodes) == len(node_feats[name])
-       assert np.all(F.asnumpy(g.ndata[name][local_nodes]) == F.asnumpy(node_feats[name]))
+       assert np.all(F.asnumpy(g.ndata[name])[local_nodes] == F.asnumpy(node_feats[name]))
    assert len(edge_feats) == 0
...
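The assertion change is the same issue seen from the test side: moving the subscript outside F.asnumpy lets NumPy, rather than the backend tensor, perform the fancy indexing, so the test also passes under TensorFlow. A minimal sketch with toy data standing in for the real node features:

import numpy as np
from dgl import backend as F

feat = F.zerocopy_from_numpy(np.arange(12, dtype=np.float32).reshape(4, 3))
local_nodes = np.array([0, 2])

# Index after converting to NumPy; feat[local_nodes] would fail on the
# TensorFlow backend, while F.asnumpy(feat)[local_nodes] works everywhere.
expected = F.asnumpy(feat)[local_nodes]
assert expected.shape == (2, 3)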
@@ -5,6 +5,7 @@ As a result, we decide to disable this test until we fixed the bug.
"""
import dgl
import sys
+import os
import random
import time
import numpy as np
@@ -101,6 +102,7 @@ def server_func(num_workers, graph_name, server_init):
    server_init.value = 1
    g.run()

+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
def test_init():
    manager = Manager()
    return_dict = manager.dict()
@@ -168,6 +170,7 @@ def check_compute_func(worker_id, graph_name, return_dict):
        print(e, file=sys.stderr)
        traceback.print_exc()

+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
def test_compute():
    manager = Manager()
    return_dict = manager.dict()
@@ -215,6 +218,7 @@ def check_sync_barrier(worker_id, graph_name, return_dict):
        print(e, file=sys.stderr)
        traceback.print_exc()

+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
def test_sync_barrier():
    manager = Manager()
    return_dict = manager.dict()
@@ -275,6 +279,7 @@ def check_mem(gidx, cond_v, shared_v):
    cond_v.notify()
    cond_v.release()

+@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow', reason="skip for tensorflow")
def test_copy_shared_mem():
    csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
    gidx = dgl.graph_index.create_graph_index(csr, True)
...
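All four decorators apply one pattern: the CI scripts export DGLBACKEND before invoking pytest, so each test can be skipped at collection time when the backend is TensorFlow. A standalone sketch of the pattern, with a hypothetical placeholder in place of the real test bodies:

import os
import unittest

@unittest.skipIf(os.getenv('DGLBACKEND') == 'tensorflow',
                 reason="skip for tensorflow")
def test_placeholder():  # hypothetical stand-in for the guarded tests
    pass

# Under DGLBACKEND=tensorflow, the decorated function raises
# unittest.SkipTest when called, which pytest reports as a skip.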
@@ -37,6 +37,6 @@ python3 -m pytest -v --junitxml=pytest_gindex.xml tests/graph_index || fail "gra
python3 -m pytest -v --junitxml=pytest_backend.xml tests/$DGLBACKEND || fail "backend-specific"
export OMP_NUM_THREADS=1
-if [ $2 != "gpu" ] && [ $1 != "tensorflow" ]; then
+if [ $2 != "gpu" ]; then
    python3 -m pytest -v --junitxml=pytest_distributed.xml tests/distributed || fail "distributed"
fi
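With the per-test skips in place, the script-level exclusion of TensorFlow is no longer needed: tests/distributed now runs for every backend on CPU (the script already relies on DGLBACKEND being set, since it runs tests/$DGLBACKEND), and only the shared-memory tests marked above are skipped under that backend.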