"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ccb7f45a3570b2175d8e8def66629528d557da3c"
Unverified Commit fbb26ee5 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] fix crash issue if empty test set (#5653)

parent 2eb3f08c
...@@ -1500,7 +1500,7 @@ def _split_even_to_part(partition_book, elements): ...@@ -1500,7 +1500,7 @@ def _split_even_to_part(partition_book, elements):
x = y = 0 x = y = 0
num_elements = len(elements) num_elements = len(elements)
block_size = num_elements // partition_book.num_partitions() block_size = num_elements // partition_book.num_partitions()
part_eles = None part_eles = F.tensor([], dtype=elements.dtype)
# compute the nonzero tensor of each partition instead of whole tensor to save memory # compute the nonzero tensor of each partition instead of whole tensor to save memory
for idx in range(0, num_elements, block_size): for idx in range(0, num_elements, block_size):
nonzero_block = F.nonzero_1d( nonzero_block = F.nonzero_1d(
...@@ -1512,10 +1512,7 @@ def _split_even_to_part(partition_book, elements): ...@@ -1512,10 +1512,7 @@ def _split_even_to_part(partition_book, elements):
start = max(x, left) - x start = max(x, left) - x
end = min(y, right) - x end = min(y, right) - x
tmp = nonzero_block[start:end] + idx tmp = nonzero_block[start:end] + idx
if part_eles is None: part_eles = F.cat((part_eles, tmp), 0)
part_eles = tmp
else:
part_eles = F.cat((part_eles, tmp), 0)
elif x >= right: elif x >= right:
break break
......
...@@ -1080,7 +1080,8 @@ def test_standalone_node_emb(): ...@@ -1080,7 +1080,8 @@ def test_standalone_node_emb():
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("hetero", [True, False]) @pytest.mark.parametrize("hetero", [True, False])
def test_split(hetero): @pytest.mark.parametrize("empty_mask", [True, False])
def test_split(hetero, empty_mask):
if hetero: if hetero:
g = create_random_hetero() g = create_random_hetero()
ntype = "n1" ntype = "n1"
...@@ -1100,8 +1101,9 @@ def test_split(hetero): ...@@ -1100,8 +1101,9 @@ def test_split(hetero):
part_method="metis", part_method="metis",
) )
node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > 30 mask_thd = 100 if empty_mask else 30
edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > 30 node_mask = np.random.randint(0, 100, size=g.num_nodes(ntype)) > mask_thd
edge_mask = np.random.randint(0, 100, size=g.num_edges(etype)) > mask_thd
selected_nodes = np.nonzero(node_mask)[0] selected_nodes = np.nonzero(node_mask)[0]
selected_edges = np.nonzero(edge_mask)[0] selected_edges = np.nonzero(edge_mask)[0]
...@@ -1173,7 +1175,8 @@ def test_split(hetero): ...@@ -1173,7 +1175,8 @@ def test_split(hetero):
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet") @unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_split_even(): @pytest.mark.parametrize("empty_mask", [True, False])
def test_split_even(empty_mask):
g = create_random_graph(10000) g = create_random_graph(10000)
num_parts = 4 num_parts = 4
num_hops = 2 num_hops = 2
...@@ -1186,10 +1189,9 @@ def test_split_even(): ...@@ -1186,10 +1189,9 @@ def test_split_even():
part_method="metis", part_method="metis",
) )
node_mask = np.random.randint(0, 100, size=g.num_nodes()) > 30 mask_thd = 100 if empty_mask else 30
edge_mask = np.random.randint(0, 100, size=g.num_edges()) > 30 node_mask = np.random.randint(0, 100, size=g.num_nodes()) > mask_thd
selected_nodes = np.nonzero(node_mask)[0] edge_mask = np.random.randint(0, 100, size=g.num_edges()) > mask_thd
selected_edges = np.nonzero(edge_mask)[0]
all_nodes1 = [] all_nodes1 = []
all_nodes2 = [] all_nodes2 = []
all_edges1 = [] all_edges1 = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment