Unverified Commit abd4a7e8 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove coo from neighborsampler. (#6844)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 7eca0612
......@@ -3,14 +3,10 @@
import torch
from torch.utils.data import functional_datapipe
from ..internal import (
compact_csc_format,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"]
......@@ -102,8 +98,6 @@ class NeighborSampler(SubgraphSampler):
replace=False,
prob_name=None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=True,
):
super().__init__(datapipe)
self.graph = graph
......@@ -116,7 +110,6 @@ class NeighborSampler(SubgraphSampler):
self.replace = replace
self.prob_name = prob_name
self.deduplicate = deduplicate
self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors
def sample_subgraphs(self, seeds):
......@@ -135,35 +128,18 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop],
self.replace,
self.prob_name,
not self.deduplicate or self.output_cscformat,
)
if self.deduplicate:
if self.output_cscformat:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(
subgraph.sampled_csc, seeds
)
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(
subgraph.sampled_csc, seeds
)
subgraph = FusedSampledSubgraphImpl(
sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
......@@ -279,8 +255,6 @@ class LayerNeighborSampler(NeighborSampler):
replace=False,
prob_name=None,
deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=True,
):
super().__init__(
datapipe,
......@@ -289,6 +263,5 @@ class LayerNeighborSampler(NeighborSampler):
replace,
prob_name,
deduplicate,
output_cscformat,
)
self.sampler = graph.sample_layer_neighbors
......@@ -402,7 +402,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
fanouts,
replace=False,
deduplicate=True,
output_cscformat=True,
)
original_row_node_ids = [
......@@ -450,7 +449,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
graph,
fanouts,
deduplicate=True,
output_cscformat=True,
)
csc_formats = [
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment