Unverified Commit fa1ae3b7 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove coo from `FusedCSCSamplingGraph`. (#6849)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent f81e514d
...@@ -12,11 +12,7 @@ from ...convert import to_homogeneous ...@@ -12,11 +12,7 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import ( from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
__all__ = [ __all__ = [
...@@ -324,9 +320,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -324,9 +320,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def in_subgraph( def in_subgraph(
self, self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
# TODO: clean up once the migration is done. ) -> SampledSubgraphImpl:
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Return the subgraph induced on the inbound edges of the given nodes. """Return the subgraph induced on the inbound edges of the given nodes.
An in subgraph is equivalent to creating a new graph using the incoming An in subgraph is equivalent to creating a new graph using the incoming
...@@ -344,7 +338,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -344,7 +338,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
FusedSampledSubgraphImpl SampledSubgraphImpl
The in subgraph. The in subgraph.
Examples Examples
...@@ -389,60 +383,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -389,60 +383,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes) _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
else:
return self._convert_to_sampled_subgraph(_in_subgraph) return self._convert_to_sampled_subgraph(_in_subgraph)
def _convert_to_fused_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
):
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'FusedSampledSubgraphImpl'."""
column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
)
column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
column_num
)
row = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = (row, column)
else:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
column[mask] - self.node_type_offset[dst_ntype_id]
)
node_pairs[etype] = (hetero_row, hetero_column)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[mask]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return FusedSampledSubgraphImpl(
sampled_csc=node_pairs, original_edge_ids=original_edge_ids
)
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None): def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = [] homogeneous_nodes = []
homogeneous_timestamps = [] homogeneous_timestamps = []
...@@ -560,9 +502,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -560,9 +502,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
# TODO: clean up once the migration is done. ) -> SampledSubgraphImpl:
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -604,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -604,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
Union[FusedSampledSubgraphImpl, SampledSubgraphImpl] SampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
Examples Examples
...@@ -638,9 +578,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -638,9 +578,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name nodes, fanouts, replace, probs_name
) )
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name): def _check_sampler_arguments(self, nodes, fanouts, probs_name):
...@@ -753,9 +690,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -753,9 +690,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
# TODO: clean up once the migration is done. ) -> SampledSubgraphImpl:
output_cscformat=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
...@@ -799,7 +734,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -799,7 +734,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
FusedSampledSubgraphImpl SampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
Examples Examples
...@@ -843,10 +778,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -843,10 +778,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
has_original_eids, has_original_eids,
probs_name, probs_name,
) )
if not output_cscformat:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def temporal_sample_neighbors( def temporal_sample_neighbors(
...@@ -906,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -906,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
FusedSampledSubgraphImpl SampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
""" """
if isinstance(nodes, dict): if isinstance(nodes, dict):
......
...@@ -1143,9 +1143,8 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1143,9 +1143,8 @@ def test_OnDiskDataset_preprocess_homogeneous():
subgraph = fused_csc_sampling_graph.sample_neighbors( subgraph = fused_csc_sampling_graph.sample_neighbors(
torch.arange(num_samples), torch.arange(num_samples),
torch.tensor([fanout]), torch.tensor([fanout]),
output_cscformat=False,
) )
assert len(subgraph.sampled_csc[0]) <= num_samples assert len(subgraph.sampled_csc.indices) <= num_samples
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified. # All metadata fields are specified.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment