"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "ea7df4db647b709e0ddc9830300be7cca625febd"
Unverified commit fa1ae3b7, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Remove coo from `FusedCSCSamplingGraph`. (#6849)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent f81e514d
......@@ -12,11 +12,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl
__all__ = [
......@@ -324,9 +320,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def in_subgraph(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming
......@@ -344,7 +338,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The in subgraph.
Examples
......@@ -389,59 +383,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph(
    self,
    C_sampled_subgraph: torch.ScriptObject,
):
    """Build a 'FusedSampledSubgraphImpl' from a fused sampled subgraph.

    The CSC layout of ``C_sampled_subgraph`` (``indptr`` / ``indices``) is
    expanded into explicit (row, column) node pairs. If the graph carries an
    ``ORIGINAL_EDGE_ID`` edge attribute, the sampled edge ids are mapped
    through it. When ``type_per_edge`` is present, the pairs (and the mapped
    edge ids, if any) are split into one entry per edge type, with node ids
    shifted by the ``node_type_offset`` of their node type.

    Parameters
    ----------
    C_sampled_subgraph : torch.ScriptObject
        Object exposing ``indptr``, ``indices``, ``original_column_node_ids``,
        ``type_per_edge`` and ``original_edge_ids``.

    Returns
    -------
    FusedSampledSubgraphImpl
        The converted subgraph.
    """
    indptr = C_sampled_subgraph.indptr
    # Number of sampled edges stored under each column of the CSC layout.
    edges_per_column = indptr[1:] - indptr[:-1]
    # Repeat every column node id once per edge so it lines up with `src`.
    dst = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
        edges_per_column
    )
    src = C_sampled_subgraph.indices
    edge_types = C_sampled_subgraph.type_per_edge
    eids = C_sampled_subgraph.original_edge_ids

    edge_attrs = self.edge_attributes
    remap_eids = edge_attrs is not None and ORIGINAL_EDGE_ID in edge_attrs
    if remap_eids:
        eids = edge_attrs[ORIGINAL_EDGE_ID][eids]

    if edge_types is None:
        # The sampled graph is already a homogeneous graph.
        return FusedSampledSubgraphImpl(
            sampled_csc=(src, dst), original_edge_ids=eids
        )

    # The sampled graph is a fused homogenized graph: split it per edge type
    # and shift node ids by the offset of their node type.
    pairs_by_etype = defaultdict(list)
    eids_by_etype = {}
    for etype, etype_id in self.edge_type_to_id.items():
        src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
        src_offset = self.node_type_offset[self.node_type_to_id[src_ntype]]
        dst_offset = self.node_type_offset[self.node_type_to_id[dst_ntype]]
        selected = edge_types == etype_id
        pairs_by_etype[etype] = (
            src[selected] - src_offset,
            dst[selected] - dst_offset,
        )
        if remap_eids:
            eids_by_etype[etype] = eids[selected]

    # Without an ORIGINAL_EDGE_ID attribute the edge ids are passed through
    # unsplit, exactly as read from the sampled subgraph.
    return FusedSampledSubgraphImpl(
        sampled_csc=pairs_by_etype,
        original_edge_ids=eids_by_etype if remap_eids else eids,
    )
return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = []
......@@ -560,9 +502,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -604,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]
SampledSubgraphImpl
The sampled subgraph.
Examples
......@@ -638,10 +578,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......@@ -753,9 +690,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
# TODO: clean up once the migration is done.
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
......@@ -799,7 +734,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.
Examples
......@@ -843,11 +778,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
has_original_eids,
probs_name,
)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def temporal_sample_neighbors(
self,
......@@ -906,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns
-------
FusedSampledSubgraphImpl
SampledSubgraphImpl
The sampled subgraph.
"""
if isinstance(nodes, dict):
......
......@@ -1143,9 +1143,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples),
torch.tensor([fanout]),
output_cscformat=False,
)
assert len(subgraph.sampled_csc[0]) <= num_samples
assert len(subgraph.sampled_csc.indices) <= num_samples
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment