Unverified Commit 9286621c authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add multi-fanout per layer test. (#6890)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 25eabbb4
...@@ -872,3 +872,73 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor): ...@@ -872,3 +872,73 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
sampled_subgraph.sampled_csc[etype].indptr, sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()), csc_formats[step][etype].indptr.to(F.ctx()),
) )
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Heterogenous sampling is not supported on GPU yet.",
)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
graph = get_hetero_graph().to(F.ctx())
items_n1 = torch.tensor([0])
items_n2 = torch.tensor([1])
names = "seed_nodes"
if sampler_type == SamplerType.Temporal:
graph.node_attributes = {
"timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.arange(graph.indices.numel()).to(F.ctx())
}
# All edges can be sampled.
items_n1 = (items_n1, torch.tensor([10]))
items_n2 = (items_n2, torch.tensor([10]))
names = (names, "timestamp")
itemset = gb.ItemSetDict(
{
"n1": gb.ItemSet(items=items_n1, names=names),
"n2": gb.ItemSet(items=items_n2, names=names),
}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
num_layer = 2
# The number of edges to be sampled for each edge types of each node.
fanouts = [torch.LongTensor([2, 1]) for _ in range(num_layer)]
sampler = _get_sampler(sampler_type)
sampler_dp = sampler(item_sampler, graph, fanouts)
if sampler_type == SamplerType.Temporal:
indices_len = [
{
"n1:e1:n2": 4,
"n2:e2:n1": 3,
},
{
"n1:e1:n2": 2,
"n2:e2:n1": 1,
},
]
else:
indices_len = [
{
"n1:e1:n2": 4,
"n2:e2:n1": 2,
},
{
"n1:e1:n2": 2,
"n2:e2:n1": 1,
},
]
for minibatch in sampler_dp:
for step, sampled_subgraph in enumerate(minibatch.sampled_subgraphs):
assert (
len(sampled_subgraph.sampled_csc["n1:e1:n2"].indices)
== indices_len[step]["n1:e1:n2"]
)
assert (
len(sampled_subgraph.sampled_csc["n2:e2:n1"].indices)
== indices_len[step]["n2:e2:n1"]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment