Unverified commit c09d4660, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Add gpu tests to `NeighborSampler`. (#6880)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>
parent 95d62394
...@@ -14,11 +14,12 @@ from torchdata.datapipes.iter import Mapper ...@@ -14,11 +14,12 @@ from torchdata.datapipes.iter import Mapper
from . import gb_test_utils from . import gb_test_utils
# Skip all tests on GPU. # Skip all tests on GPU when sampling with TemporalNeighborSampler.
pytestmark = pytest.mark.skipif( def _check_sampler_type(sampler_type):
F._default_context_str != "cpu", if F._default_context_str != "cpu" and sampler_type == SamplerType.Temporal:
reason="GraphBolt sampling tests are only supported on CPU.", pytest.skip(
) "TemporalNeighborSampler sampling tests are only supported on CPU."
)
class SamplerType(Enum): class SamplerType(Enum):
...@@ -108,6 +109,7 @@ def test_NeighborSampler_fanouts(labor): ...@@ -108,6 +109,7 @@ def test_NeighborSampler_fanouts(labor):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Node(sampler_type): def test_SubgraphSampler_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx() F.ctx()
) )
...@@ -139,6 +141,7 @@ def to_link_batch(data): ...@@ -139,6 +141,7 @@ def to_link_batch(data):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link(sampler_type): def test_SubgraphSampler_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx() F.ctx()
) )
...@@ -166,6 +169,7 @@ def test_SubgraphSampler_Link(sampler_type): ...@@ -166,6 +169,7 @@ def test_SubgraphSampler_Link(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link_With_Negative(sampler_type): def test_SubgraphSampler_Link_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to( graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx() F.ctx()
) )
...@@ -216,6 +220,7 @@ def get_hetero_graph(): ...@@ -216,6 +220,7 @@ def get_hetero_graph():
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Node_Hetero(sampler_type): def test_SubgraphSampler_Node_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
items = torch.arange(3) items = torch.arange(3)
names = "seed_nodes" names = "seed_nodes"
...@@ -244,6 +249,7 @@ def test_SubgraphSampler_Node_Hetero(sampler_type): ...@@ -244,6 +249,7 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link_Hetero(sampler_type): def test_SubgraphSampler_Link_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs" first_names = "node_pairs"
...@@ -287,6 +293,7 @@ def test_SubgraphSampler_Link_Hetero(sampler_type): ...@@ -287,6 +293,7 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type): def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs" first_names = "node_pairs"
...@@ -331,6 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type): ...@@ -331,6 +338,7 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type): def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs" first_names = "node_pairs"
...@@ -375,6 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type): ...@@ -375,6 +383,7 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type): def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs" first_names = "node_pairs"
...@@ -415,15 +424,18 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type): ...@@ -415,15 +424,18 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
@unittest.skipIf(
F._default_context_str != "cpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type): @pytest.mark.parametrize(
"replace",
[False, True],
)
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
_check_sampler_type(sampler_type)
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
num_nodes = 5 num_nodes = 5
num_edges = 9 num_edges = 9
num_ntypes = 3 num_ntypes = 3
...@@ -477,40 +489,42 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type): ...@@ -477,40 +489,42 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type):
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type) sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts, replace=True) sampler_dp = sampler(item_sampler, graph, fanouts, replace=replace)
for data in sampler_dp: for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs: for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.sampled_csc.items(): for _, value in sampledsubgraph.sampled_csc.items():
assert torch.equal( assert torch.equal(
torch.ge(value.indices, torch.zeros(len(value.indices))), torch.ge(
torch.ones(len(value.indices)), value.indices,
torch.zeros(len(value.indices)).to(F.ctx()),
),
torch.ones(len(value.indices)).to(F.ctx()),
) )
assert torch.equal( assert torch.equal(
torch.ge(value.indptr, torch.zeros(len(value.indptr))), torch.ge(
torch.ones(len(value.indptr)), value.indptr, torch.zeros(len(value.indptr)).to(F.ctx())
),
torch.ones(len(value.indptr)).to(F.ctx()),
) )
for _, value in sampledsubgraph.original_column_node_ids.items(): for _, value in sampledsubgraph.original_column_node_ids.items():
assert torch.equal( assert torch.equal(
torch.ge(value, torch.zeros(len(value))), torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)), torch.ones(len(value)).to(F.ctx()),
) )
for _, value in sampledsubgraph.original_row_node_ids.items(): for _, value in sampledsubgraph.original_row_node_ids.items():
assert torch.equal( assert torch.equal(
torch.ge(value, torch.zeros(len(value))), torch.ge(value, torch.zeros(len(value)).to(F.ctx())),
torch.ones(len(value)), torch.ones(len(value)).to(F.ctx()),
) )
@unittest.skipIf(
F._default_context_str != "cpu",
reason="Fails due to randomness on the GPU.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sampler_type", "sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Homo(sampler_type): def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph( graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4]) ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
) )
...@@ -551,7 +565,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type): ...@@ -551,7 +565,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
torch.tensor([0, 1, 2, 4]).to(F.ctx()), torch.tensor([0, 1, 2, 4]).to(F.ctx()),
] ]
seeds = [ seeds = [
torch.tensor([0, 3, 4, 5, 2, 2, 4]).to(F.ctx()), torch.tensor([0, 2, 2, 3, 4, 4, 5]).to(F.ctx()),
torch.tensor([0, 3, 4]).to(F.ctx()), torch.tensor([0, 3, 4]).to(F.ctx()),
] ]
for data in datapipe: for data in datapipe:
...@@ -564,7 +578,8 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type): ...@@ -564,7 +578,8 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
sampled_subgraph.sampled_csc.indptr, indptr[step] sampled_subgraph.sampled_csc.indptr, indptr[step]
) )
assert torch.equal( assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step] torch.sort(sampled_subgraph.original_column_node_ids)[0],
seeds[step],
) )
...@@ -573,6 +588,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type): ...@@ -573,6 +588,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal], [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
) )
def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type): def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2) items = torch.arange(2)
names = "seed_nodes" names = "seed_nodes"
...@@ -660,11 +676,11 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type): ...@@ -660,11 +676,11 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
@unittest.skipIf( @unittest.skipIf(
F._default_context_str != "cpu", F._default_context_str == "gpu",
reason="Fails due to randomness on the GPU.", reason="Fails due to different result on the GPU.",
) )
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo(labor): def test_SubgraphSampler_unique_csc_format_Homo_cpu(labor):
torch.manual_seed(1205) torch.manual_seed(1205)
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4])) graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
graph = gb.from_dglgraph(graph, True).to(F.ctx()) graph = gb.from_dglgraph(graph, True).to(F.ctx())
...@@ -682,7 +698,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor): ...@@ -682,7 +698,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
item_sampler, item_sampler,
graph, graph,
fanouts, fanouts,
replace=False,
deduplicate=True, deduplicate=True,
) )
...@@ -719,6 +734,65 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor): ...@@ -719,6 +734,65 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
) )
@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_gpu(labor):
    """Compare deduplicated sampling output against fixed expected tensors.

    Builds a six-edge homogeneous graph, samples two layers from seed nodes
    ``[0, 3, 4]`` in one batch with ``deduplicate=True``, and checks every
    sampled subgraph's row ids, CSC indices, CSC indptr and column ids.

    The test is skipped when the default context is "cpu"; the expected
    values below are the ones recorded for the non-CPU run.

    Parameters
    ----------
    labor : bool
        If True use ``gb.LayerNeighborSampler``, otherwise
        ``gb.NeighborSampler``.
    """
    torch.manual_seed(1205)
    device = F.ctx()

    # Edge list given as (source ids, destination ids).
    src = [5, 0, 7, 7, 2, 4]
    dst = [0, 1, 2, 2, 3, 4]
    graph = dgl.graph((src, dst))
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(device)

    # All seeds go into a single batch.
    seed_nodes = torch.LongTensor([0, 3, 4])
    itemset = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(itemset, batch_size=len(seed_nodes))
    item_sampler = item_sampler.copy_to(device)

    num_layer = 2
    # NOTE(review): a fanout of -1 presumably means "take every neighbor";
    # confirm against the sampler documentation.
    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]

    if labor:
        Sampler = gb.LayerNeighborSampler
    else:
        Sampler = gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )

    # One entry per sampled subgraph, in iteration order:
    # (original_row_node_ids, csc indices, csc indptr, original_column_node_ids)
    expected = [
        (
            torch.tensor([0, 3, 4, 2, 5, 7]).to(device),
            torch.tensor([4, 3, 2, 5, 5]).to(device),
            torch.tensor([0, 1, 2, 3, 5, 5]).to(device),
            torch.tensor([0, 3, 4, 2, 5]).to(device),
        ),
        (
            torch.tensor([0, 3, 4, 2, 5]).to(device),
            torch.tensor([4, 3, 2]).to(device),
            torch.tensor([0, 1, 2, 3]).to(device),
            torch.tensor([0, 3, 4]).to(device),
        ),
    ]

    for minibatch in datapipe:
        for layer, subgraph in enumerate(minibatch.sampled_subgraphs):
            want_rows, want_indices, want_indptr, want_cols = expected[layer]
            assert torch.equal(subgraph.original_row_node_ids, want_rows)
            csc = subgraph.sampled_csc
            assert torch.equal(csc.indices, want_indices)
            assert torch.equal(csc.indptr, want_indptr)
            assert torch.equal(subgraph.original_column_node_ids, want_cols)
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero(labor): def test_SubgraphSampler_unique_csc_format_Hetero(labor):
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment