"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bfe8b4131505d4ee722276abfe1d752234ffd465"
Unverified Commit abd4a7e8 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove coo from neighborsampler. (#6844)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 7eca0612
...@@ -3,14 +3,10 @@ ...@@ -3,14 +3,10 @@
import torch import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from ..internal import ( from ..internal import compact_csc_format, unique_and_compact_csc_formats
compact_csc_format,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from ..subgraph_sampler import SubgraphSampler from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"] __all__ = ["NeighborSampler", "LayerNeighborSampler"]
...@@ -102,8 +98,6 @@ class NeighborSampler(SubgraphSampler): ...@@ -102,8 +98,6 @@ class NeighborSampler(SubgraphSampler):
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=True,
): ):
super().__init__(datapipe) super().__init__(datapipe)
self.graph = graph self.graph = graph
...@@ -116,7 +110,6 @@ class NeighborSampler(SubgraphSampler): ...@@ -116,7 +110,6 @@ class NeighborSampler(SubgraphSampler):
self.replace = replace self.replace = replace
self.prob_name = prob_name self.prob_name = prob_name
self.deduplicate = deduplicate self.deduplicate = deduplicate
self.output_cscformat = output_cscformat
self.sampler = graph.sample_neighbors self.sampler = graph.sample_neighbors
def sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
...@@ -135,35 +128,18 @@ class NeighborSampler(SubgraphSampler): ...@@ -135,35 +128,18 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop], self.fanouts[hop],
self.replace, self.replace,
self.prob_name, self.prob_name,
not self.deduplicate or self.output_cscformat,
) )
if self.deduplicate: if self.deduplicate:
if self.output_cscformat: (
( original_row_node_ids,
original_row_node_ids, compacted_csc_format,
compacted_csc_format, ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
) = unique_and_compact_csc_formats( subgraph = SampledSubgraphImpl(
subgraph.sampled_csc, seeds sampled_csc=compacted_csc_format,
) original_column_node_ids=seeds,
subgraph = SampledSubgraphImpl( original_row_node_ids=original_row_node_ids,
sampled_csc=compacted_csc_format, original_edge_ids=subgraph.original_edge_ids,
original_column_node_ids=seeds, )
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(
subgraph.sampled_csc, seeds
)
subgraph = FusedSampledSubgraphImpl(
sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else: else:
( (
original_row_node_ids, original_row_node_ids,
...@@ -279,8 +255,6 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -279,8 +255,6 @@ class LayerNeighborSampler(NeighborSampler):
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
# TODO: clean up once the migration is done.
output_cscformat=True,
): ):
super().__init__( super().__init__(
datapipe, datapipe,
...@@ -289,6 +263,5 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -289,6 +263,5 @@ class LayerNeighborSampler(NeighborSampler):
replace, replace,
prob_name, prob_name,
deduplicate, deduplicate,
output_cscformat,
) )
self.sampler = graph.sample_layer_neighbors self.sampler = graph.sample_layer_neighbors
...@@ -402,7 +402,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor): ...@@ -402,7 +402,6 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
fanouts, fanouts,
replace=False, replace=False,
deduplicate=True, deduplicate=True,
output_cscformat=True,
) )
original_row_node_ids = [ original_row_node_ids = [
...@@ -450,7 +449,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor): ...@@ -450,7 +449,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
graph, graph,
fanouts, fanouts,
deduplicate=True, deduplicate=True,
output_cscformat=True,
) )
csc_formats = [ csc_formats = [
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment