".github/vscode:/vscode.git/clone" did not exist on "b63c956860373ef169cfb24c2088a9b173a72bfd"
Unverified Commit 566910d8 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update gpu sampling tests of `sample_beighbors`. (#6892)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent c09d4660
......@@ -1590,11 +1590,9 @@ def test_csc_sampling_graph_to_pinned_memory():
assert graph.edge_attributes[key].is_pinned()
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_sample_neighbors_homo():
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("is_pinned", [False, True])
def test_sample_neighbors_homo(labor, is_pinned):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -1611,10 +1609,16 @@ def test_sample_neighbors_homo():
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
if F._default_context_str == "gpu":
if is_pinned:
graph.pin_memory_()
else:
graph = graph.to(F.ctx())
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
# Verify in subgraph.
sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
......@@ -1628,7 +1632,7 @@ def test_sample_neighbors_homo():
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Heterogenous sampling on gpu is not supported yet.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
......@@ -1715,7 +1719,7 @@ def test_sample_neighbors_hetero(labor):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Heterogenous sampling on gpu is not supported yet.",
)
@pytest.mark.parametrize(
"fanouts, expected_sampled_num1, expected_sampled_num2",
......@@ -1789,7 +1793,7 @@ def test_sample_neighbors_fanouts(
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize(
"replace, expected_sampled_num1, expected_sampled_num2",
......@@ -1846,12 +1850,9 @@ def test_sample_neighbors_replace(
assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_homo(labor):
@pytest.mark.parametrize("is_pinned", [False, True])
def test_sample_neighbors_return_eids_homo(labor, is_pinned):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -1873,24 +1874,32 @@ def test_sample_neighbors_return_eids_homo(labor):
graph = gb.fused_csc_sampling_graph(
indptr, indices, edge_attributes=edge_attributes
)
if F._default_context_str == "gpu":
if is_pinned:
graph.pin_memory_()
else:
graph = graph.to(F.ctx())
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([-1]))
# Verify in subgraph.
expected_reverse_edge_ids = edge_attributes[gb.ORIGINAL_EDGE_ID][
torch.tensor([3, 4, 7, 8, 9, 10, 11])
]
assert torch.equal(expected_reverse_edge_ids, subgraph.original_edge_ids)
].to(F.ctx())
assert torch.equal(
torch.sort(expected_reverse_edge_ids)[0],
torch.sort(subgraph.original_edge_ids)[0],
)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Heterogenous sampling on gpu is not supported yet.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_return_eids_hetero(labor):
......@@ -1950,7 +1959,7 @@ def test_sample_neighbors_return_eids_hetero(labor):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
......@@ -2004,7 +2013,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
......@@ -2049,7 +2058,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
......@@ -2134,7 +2143,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment