Unverified Commit 57281e9f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] sample with unknown etype (#6888)

parent e9deff7d
...@@ -326,6 +326,95 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type): ...@@ -326,6 +326,95 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
# "e11" and "e22" are not valid edge types.
itemset = gb.ItemSetDict(
{
"n1:e11:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e22:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "node_pairs"
second_items = torch.LongTensor([[0, 0, 1, 1, 2, 2], [0, 1, 1, 0, 0, 1]]).T
second_names = "node_pairs"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
first_items = (first_items, torch.randint(0, 10, (4,)))
first_names = (first_names, "timestamp")
second_items = (second_items, torch.randint(0, 10, (6,)))
second_names = (second_names, "timestamp")
# "e11" and "e22" are not valid edge types.
itemset = gb.ItemSetDict(
{
"n1:e11:n2": gb.ItemSet(
first_items,
names=first_names,
),
"n2:e22:n1": gb.ItemSet(
second_items,
names=second_names,
),
}
)
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
datapipe = gb.UniformNegativeSampler(datapipe, graph, 1)
sampler = _get_sampler(sampler_type)
datapipe = sampler(datapipe, graph, fanouts)
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
assert len(list(datapipe)) == 5
@unittest.skipIf( @unittest.skipIf(
F._default_context_str != "cpu", F._default_context_str != "cpu",
reason="Sampling with replacement not yet supported on GPU.", reason="Sampling with replacement not yet supported on GPU.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment