Unverified commit ec5c515c, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Add `deduplicate` parameter to `NeighborSampler` to support disabling deduplication. (#6419)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 34af2d68
...@@ -44,6 +44,10 @@ class NeighborSampler(SubgraphSampler): ...@@ -44,6 +44,10 @@ class NeighborSampler(SubgraphSampler):
probabilities corresponding to each neighboring edge of a node. probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges. of elements equalling the total number of edges.
deduplicate: bool
Boolean indicating whether seeds between hops will be deduplicated.
If True, duplicate elements in seeds are collapsed into a single
occurrence. Otherwise, duplicates are kept. Note that passing False is
not implemented yet and currently raises a RuntimeError.
Examples Examples
------- -------
...@@ -77,6 +81,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -77,6 +81,7 @@ class NeighborSampler(SubgraphSampler):
fanouts, fanouts,
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
...@@ -88,6 +93,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -88,6 +93,7 @@ class NeighborSampler(SubgraphSampler):
self.fanouts.append(fanout) self.fanouts.append(fanout)
self.replace = replace self.replace = replace
self.prob_name = prob_name self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds): def _sample_subgraphs(self, seeds):
...@@ -108,9 +114,12 @@ class NeighborSampler(SubgraphSampler): ...@@ -108,9 +114,12 @@ class NeighborSampler(SubgraphSampler):
self.prob_name, self.prob_name,
) )
original_column_node_ids = seeds original_column_node_ids = seeds
seeds, compacted_node_pairs = unique_and_compact_node_pairs( if self.deduplicate:
subgraph.node_pairs, seeds seeds, compacted_node_pairs = unique_and_compact_node_pairs(
) subgraph.node_pairs, seeds
)
else:
raise RuntimeError("Not implemented yet.")
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=compacted_node_pairs, node_pairs=compacted_node_pairs,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
...@@ -166,6 +175,10 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -166,6 +175,10 @@ class LayerNeighborSampler(NeighborSampler):
probabilities corresponding to each neighboring edge of a node. probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges. of elements equalling the total number of edges.
deduplicate: bool
Boolean indicating whether seeds between hops will be deduplicated.
If True, duplicate elements in seeds are collapsed into a single
occurrence. Otherwise, duplicates are kept.
Examples Examples
------- -------
...@@ -202,6 +215,9 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -202,6 +215,9 @@ class LayerNeighborSampler(NeighborSampler):
fanouts, fanouts,
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True,
): ):
super().__init__(datapipe, graph, fanouts, replace, prob_name) super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate
)
self.sampler = graph.sample_layer_neighbors self.sampler = graph.sample_layer_neighbors
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment