"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4ab94cb7aa01828960f48f55fa477110f4d191fb"
Unverified Commit 06074d73 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enrich node types for input/output nodes of sampled subgraph (#6348)

parent adf49937
...@@ -242,8 +242,6 @@ class CSCSamplingGraph: ...@@ -242,8 +242,6 @@ class CSCSamplingGraph:
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id mask = type_per_edge == etype_id
if mask.count_nonzero() == 0:
continue
hetero_row = row[mask] - self.node_type_offset[src_ntype_id] hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = ( hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id] column[mask] - self.node_type_offset[dst_ntype_id]
......
...@@ -78,6 +78,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -78,6 +78,7 @@ class NeighborSampler(SubgraphSampler):
3 3
""" """
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph
# Convert fanouts to a list of tensors. # Convert fanouts to a list of tensors.
self.fanouts = [] self.fanouts = []
for fanout in fanouts: for fanout in fanouts:
...@@ -91,6 +92,13 @@ class NeighborSampler(SubgraphSampler): ...@@ -91,6 +92,13 @@ class NeighborSampler(SubgraphSampler):
def _sample_subgraphs(self, seeds): def _sample_subgraphs(self, seeds):
subgraphs = [] subgraphs = []
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys())
seeds = {
ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes
}
for hop in range(num_layers): for hop in range(num_layers):
subgraph = self.sampler( subgraph = self.sampler(
seeds, seeds,
......
...@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor): ...@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor):
torch.LongTensor([0, 2]), torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]), torch.LongTensor([0, 0]),
), ),
"n1:e1:n2": (
torch.LongTensor([]),
torch.LongTensor([]),
),
} }
assert len(subgraph.node_pairs) == 1 assert len(subgraph.node_pairs) == 2
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0]) assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1]) assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
......
...@@ -129,6 +129,23 @@ def get_hetero_graph(): ...@@ -129,6 +129,23 @@ def get_hetero_graph():
) )
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node_Hetero(labor):
graph = get_hetero_graph()
itemset = gb.ItemSetDict(
{"n2": gb.ItemSet(torch.arange(3), names="seed_nodes")}
)
item_sampler = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(item_sampler, graph, fanouts)
assert len(list(sampler_dp)) == 2
for minibatch in sampler_dp:
blocks = minibatch.to_dgl_blocks()
assert len(blocks) == num_layer
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Link_Hetero(labor): def test_SubgraphSampler_Link_Hetero(labor):
graph = get_hetero_graph() graph = get_hetero_graph()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment