Unverified Commit 815b88a6 authored by Serge Panev's avatar Serge Panev Committed by GitHub
Browse files

[Dist][Test] Add blacklisting mechanism + misc fixes (#4950)


Signed-off-by: default avatarSerge Panev <spanev@nvidia.com>
Signed-off-by: default avatarSerge Panev <spanev@nvidia.com>
parent 226d1159
...@@ -124,21 +124,6 @@ def test_dist_graph(g): ...@@ -124,21 +124,6 @@ def test_dist_graph(g):
########################################## ##########################################
########### DistGraphServices ########### ########### DistGraphServices ###########
########################################## ##########################################
nids = F.arange(0, 16)
# Test in_degrees
orig_in_degrees = g.ndata['in_degrees']
local_in_degrees = g.in_degrees(nids)
F.allclose(local_in_degrees, orig_in_degrees[nids])
# Test out_degrees
orig_out_degrees = g.ndata['out_degrees']
local_out_degrees = g.out_degrees(nids)
F.allclose(local_out_degrees, orig_out_degrees[nids])
find_edges_test(g)
edge_subgraph_test(g)
sample_neighbors_test(g)
def find_edges_test(g, orig_nid_map): def find_edges_test(g, orig_nid_map):
etypes = g.canonical_etypes etypes = g.canonical_etypes
...@@ -197,6 +182,20 @@ def sample_neighbors_test(g): ...@@ -197,6 +182,20 @@ def sample_neighbors_test(g):
sample_neighbors_with_args(g, size=2**12, fanout=1) sample_neighbors_with_args(g, size=2**12, fanout=1)
def test_dist_graph_services(g): def test_dist_graph_services(g):
# in_degrees and out_degrees does not support heterograph
if len(g.etypes) == 1:
nids = F.arange(0, 128)
# Test in_degrees
orig_in_degrees = g.ndata['in_degrees']
local_in_degrees = g.in_degrees(nids)
F.allclose(local_in_degrees, orig_in_degrees[nids])
# Test out_degrees
orig_out_degrees = g.ndata['out_degrees']
local_out_degrees = g.out_degrees(nids)
F.allclose(local_out_degrees, orig_out_degrees[nids])
num_nodes = {ntype : g.num_nodes(ntype) for ntype in g.ntypes} num_nodes = {ntype : g.num_nodes(ntype) for ntype in g.ntypes}
orig_nid_map = dict() orig_nid_map = dict()
...@@ -425,7 +424,7 @@ class NeighborSampler(object): ...@@ -425,7 +424,7 @@ class NeighborSampler(object):
def distdataloader_test(g, batch_size, drop_last, shuffle): def distdataloader_test(g, batch_size, drop_last, shuffle):
# We sample only a subset to minimize the test runtime # We sample only a subset to minimize the test runtime
num_nodes_to_sample = g.num_nodes() * 0.05 num_nodes_to_sample = int(g.num_nodes() * 0.05)
# To make sure that drop_last is tested # To make sure that drop_last is tested
if num_nodes_to_sample % batch_size == 0: if num_nodes_to_sample % batch_size == 0:
num_nodes_to_sample -= 1 num_nodes_to_sample -= 1
...@@ -490,7 +489,7 @@ def distdataloader_test(g, batch_size, drop_last, shuffle): ...@@ -490,7 +489,7 @@ def distdataloader_test(g, batch_size, drop_last, shuffle):
def distnodedataloader_test(g, batch_size, drop_last, shuffle, def distnodedataloader_test(g, batch_size, drop_last, shuffle,
num_workers, orig_nid_map, orig_uv_map): num_workers, orig_nid_map, orig_uv_map):
# We sample only a subset to minimize the test runtime # We sample only a subset to minimize the test runtime
num_nodes_to_sample = g.num_nodes(g.ntypes[-1]) * 0.05 num_nodes_to_sample = int(g.num_nodes(g.ntypes[-1]) * 0.05)
# To make sure that drop_last is tested # To make sure that drop_last is tested
if num_nodes_to_sample % batch_size == 0: if num_nodes_to_sample % batch_size == 0:
num_nodes_to_sample -= 1 num_nodes_to_sample -= 1
...@@ -544,7 +543,7 @@ def distnodedataloader_test(g, batch_size, drop_last, shuffle, ...@@ -544,7 +543,7 @@ def distnodedataloader_test(g, batch_size, drop_last, shuffle,
def distedgedataloader_test(g, batch_size, drop_last, shuffle, def distedgedataloader_test(g, batch_size, drop_last, shuffle,
num_workers, orig_nid_map, orig_uv_map, num_negs): num_workers, orig_nid_map, orig_uv_map, num_negs):
# We sample only a subset to minimize the test runtime # We sample only a subset to minimize the test runtime
num_edges_to_sample = g.num_edges(g.etypes[-1]) * 0.05 num_edges_to_sample = int(g.num_edges(g.etypes[-1]) * 0.05)
# To make sure that drop_last is tested # To make sure that drop_last is tested
if num_edges_to_sample % batch_size == 0: if num_edges_to_sample % batch_size == 0:
num_edges_to_sample -= 1 num_edges_to_sample -= 1
...@@ -603,11 +602,6 @@ def distedgedataloader_test(g, batch_size, drop_last, shuffle, ...@@ -603,11 +602,6 @@ def distedgedataloader_test(g, batch_size, drop_last, shuffle,
def multi_distdataloader_test(g, dataloader_class): def multi_distdataloader_test(g, dataloader_class):
total_num_items = g.num_nodes(g.ntypes[-1]) if "Node" in dataloader_class.__name__ else g.num_edges(g.etypes[-1]) total_num_items = g.num_nodes(g.ntypes[-1]) if "Node" in dataloader_class.__name__ else g.num_edges(g.etypes[-1])
# We sample only a subset to minimize the test runtime
num_items_to_sample = total_num_items * 0.05
# To make sure that drop_last is tested
if num_items_to_sample % batch_size == 0:
num_items_to_sample -= 1
num_dataloaders=4 num_dataloaders=4
batch_size=32 batch_size=32
...@@ -615,6 +609,12 @@ def multi_distdataloader_test(g, dataloader_class): ...@@ -615,6 +609,12 @@ def multi_distdataloader_test(g, dataloader_class):
dataloaders = [] dataloaders = []
dl_iters = [] dl_iters = []
# We sample only a subset to minimize the test runtime
num_items_to_sample = int(total_num_items * 0.05)
# To make sure that drop_last is tested
if num_items_to_sample % batch_size == 0:
num_items_to_sample -= 1
if len(g.ntypes) == 1: if len(g.ntypes) == 1:
train_ids = F.arange(0, num_items_to_sample) train_ids = F.arange(0, num_items_to_sample)
else: else:
...@@ -695,12 +695,23 @@ elif mode == "client": ...@@ -695,12 +695,23 @@ elif mode == "client":
"DistDataLoader": test_dist_dataloader, "DistDataLoader": test_dist_dataloader,
} }
target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "") targets = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
if target not in target_func_map: targets = targets.replace(' ', '').split(',') if targets else []
blacklist = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST", "")
blacklist = blacklist.replace(' ', '').split(',') if blacklist else []
for to_bl in blacklist:
target_func_map.pop(to_bl, None)
if not targets:
for test_func in target_func_map.values(): for test_func in target_func_map.values():
test_func(g) test_func(g)
else: else:
target_func_map[target](g) for target in targets:
if target in target_func_map:
target_func_map[target](g)
else:
print(f"Tests not implemented for target '{target}'")
else: else:
exit(1) exit(1)
...@@ -14,6 +14,7 @@ from dgl.distributed import partition_graph ...@@ -14,6 +14,7 @@ from dgl.distributed import partition_graph
graph_name = os.environ.get("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph") graph_name = os.environ.get("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph")
target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "") target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
blacklist = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST", "")
shared_workspace = os.environ.get('DIST_DGL_TEST_WORKSPACE', '/shared_workspace/dgl_dist_tensor_test/') shared_workspace = os.environ.get('DIST_DGL_TEST_WORKSPACE', '/shared_workspace/dgl_dist_tensor_test/')
...@@ -66,7 +67,7 @@ def create_graph(num_part, dist_graph_path, hetero): ...@@ -66,7 +67,7 @@ def create_graph(num_part, dist_graph_path, hetero):
for _, etype, _ in etypes: for _, etype, _ in etypes:
edge_u, edge_v = g.find_edges(F.arange(0, g.number_of_edges(etype))) edge_u, edge_v = g.find_edges(F.arange(0, g.number_of_edges(etype)), etype=etype)
g.edges[etype].data["edge_u"] = edge_u g.edges[etype].data["edge_u"] = edge_u
g.edges[etype].data["edge_v"] = edge_v g.edges[etype].data["edge_v"] = edge_v
...@@ -136,6 +137,7 @@ def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem): ...@@ -136,6 +137,7 @@ def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem):
cmd_envs = ( cmd_envs = (
base_envs + f"DIST_DGL_TEST_PART_ID={part_id} " base_envs + f"DIST_DGL_TEST_PART_ID={part_id} "
f"DIST_DGL_TEST_OBJECT_TYPE={target} " f"DIST_DGL_TEST_OBJECT_TYPE={target} "
f"DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST={blacklist} "
f"DIST_DGL_TEST_MODE=client " f"DIST_DGL_TEST_MODE=client "
) )
procs.append( procs.append(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment