Unverified Commit ec5c515c authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add parameter to `NeighborSampler` to support NOT deduplication. (#6419)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 34af2d68
......@@ -44,6 +44,10 @@ class NeighborSampler(SubgraphSampler):
probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges.
deduplicate: bool
Boolean indicating whether seeds between hops will be deduplicated.
If True, the same elements in seeds will be deleted to only one.
Otherwise, the same elements will be remained.
Examples
-------
......@@ -77,6 +81,7 @@ class NeighborSampler(SubgraphSampler):
fanouts,
replace=False,
prob_name=None,
deduplicate=True,
):
super().__init__(datapipe)
self.graph = graph
......@@ -88,6 +93,7 @@ class NeighborSampler(SubgraphSampler):
self.fanouts.append(fanout)
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.sampler = graph.sample_neighbors
def _sample_subgraphs(self, seeds):
......@@ -108,9 +114,12 @@ class NeighborSampler(SubgraphSampler):
self.prob_name,
)
original_column_node_ids = seeds
seeds, compacted_node_pairs = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
)
if self.deduplicate:
seeds, compacted_node_pairs = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
)
else:
raise RuntimeError("Not implemented yet.")
subgraph = SampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=original_column_node_ids,
......@@ -166,6 +175,10 @@ class LayerNeighborSampler(NeighborSampler):
probabilities corresponding to each neighboring edge of a node.
It must be a 1D floating-point or boolean tensor, with the number
of elements equalling the total number of edges.
deduplicate: bool
Boolean indicating whether seeds between hops will be deduplicated.
If True, the same elements in seeds will be deleted to only one.
Otherwise, the same elements will be remained.
Examples
-------
......@@ -202,6 +215,9 @@ class LayerNeighborSampler(NeighborSampler):
fanouts,
replace=False,
prob_name=None,
deduplicate=True,
):
super().__init__(datapipe, graph, fanouts, replace, prob_name)
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate
)
self.sampler = graph.sample_layer_neighbors
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment