Unverified commit 06074d73, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] enrich node types for input/output nodes of sampled subgraph (#6348)

parent adf49937
......@@ -242,8 +242,6 @@ class CSCSamplingGraph:
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
if mask.count_nonzero() == 0:
continue
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id]
......
......@@ -78,6 +78,7 @@ class NeighborSampler(SubgraphSampler):
3
"""
super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors.
self.fanouts = []
for fanout in fanouts:
......@@ -91,6 +92,13 @@ class NeighborSampler(SubgraphSampler):
def _sample_subgraphs(self, seeds):
subgraphs = []
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys())
seeds = {
ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes
}
for hop in range(num_layers):
subgraph = self.sampler(
seeds,
......
......@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor):
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
),
"n1:e1:n2": (
torch.LongTensor([]),
torch.LongTensor([]),
),
}
assert len(subgraph.node_pairs) == 1
assert len(subgraph.node_pairs) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
......
......@@ -129,6 +129,23 @@ def get_hetero_graph():
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node_Hetero(labor):
    """Sample around "n2" seed nodes on the hetero graph and check the output.

    The item set holds only the node type "n2" (three seeds from
    ``torch.arange(3)``), batched two at a time. The test expects the
    sampling datapipe to yield two minibatches, and each minibatch to
    convert into exactly one DGL block per sampling layer.

    Parameters
    ----------
    labor : bool
        When True the datapipe is built with ``gb.LayerNeighborSampler``;
        otherwise ``gb.NeighborSampler`` is used.
    """
    num_layer = 2

    graph = get_hetero_graph()
    seed_items = gb.ItemSet(torch.arange(3), names="seed_nodes")
    itemset = gb.ItemSetDict({"n2": seed_items})
    item_sampler = gb.ItemSampler(itemset, batch_size=2)

    # One fanout tensor per layer, each a fresh tensor object.
    fanouts = []
    for _ in range(num_layer):
        fanouts.append(torch.LongTensor([2]))

    if labor:
        Sampler = gb.LayerNeighborSampler
    else:
        Sampler = gb.NeighborSampler
    sampler_dp = Sampler(item_sampler, graph, fanouts)

    # First pass over the datapipe: count the minibatches.
    minibatch_count = len(list(sampler_dp))
    assert minibatch_count == 2

    # Second pass: every minibatch must produce one block per layer.
    for minibatch in sampler_dp:
        assert len(minibatch.to_dgl_blocks()) == num_layer
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor):
graph = get_hetero_graph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment