"vscode:/vscode.git/clone" did not exist on "6b727842d7fd370ac057c092d913bf8557dd32c2"
Unverified commit 1785acff, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Refactor csc format sampled subgraph. (#6553)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent fdeda8a8
"""Base types and utilities for Graph Bolt.""" """Base types and utilities for Graph Bolt."""
from dataclasses import dataclass
import torch import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
...@@ -13,6 +15,7 @@ __all__ = [ ...@@ -13,6 +15,7 @@ __all__ = [
"etype_tuple_to_str", "etype_tuple_to_str",
"CopyTo", "CopyTo",
"isin", "isin",
"CSCFormatBase",
] ]
CANONICAL_ETYPE_DELIMITER = ":" CANONICAL_ETYPE_DELIMITER = ":"
...@@ -111,3 +114,21 @@ class CopyTo(IterDataPipe): ...@@ -111,3 +114,21 @@ class CopyTo(IterDataPipe):
for data in self.datapipe: for data in self.datapipe:
data = recursive_apply(data, apply_to, self.device) data = recursive_apply(data, apply_to, self.device)
yield data yield data
@dataclass
class CSCFormatBase:
    r"""Basic class representing data in Compressed Sparse Column (CSC) format.

    The class is a plain container: it stores the two tensors it is given and
    performs no validation of its own. Both fields default to ``None``.

    Examples
    --------
    >>> indptr = torch.tensor([0, 1, 3])
    >>> indices = torch.tensor([1, 4, 2])
    >>> csc_format_base = CSCFormatBase(indptr=indptr, indices=indices)
    >>> print(csc_format_base.indptr)
    tensor([0, 1, 3])
    >>> print(csc_format_base.indices)
    tensor([1, 4, 2])
    """

    # Column pointers: by CSC convention, the entries of column ``i`` are
    # ``indices[indptr[i]:indptr[i + 1]]``.
    indptr: torch.Tensor = None
    # Row indices of the stored entries, grouped by column.
    indices: torch.Tensor = None
...@@ -15,7 +15,11 @@ from ...convert import to_homogeneous ...@@ -15,7 +15,11 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import FusedSampledSubgraphImpl from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
__all__ = [ __all__ = [
...@@ -342,9 +346,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -342,9 +346,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes) _in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph) return self._convert_to_fused_sampled_subgraph(_in_subgraph)
def _convert_to_sampled_subgraph( def _convert_to_fused_sampled_subgraph(
self, self,
C_sampled_subgraph: torch.ScriptObject, C_sampled_subgraph: torch.ScriptObject,
): ):
...@@ -400,13 +404,109 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -400,13 +404,109 @@ class FusedCSCSamplingGraph(SamplingGraph):
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id]) homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes) return torch.cat(homogeneous_nodes)
def _convert_to_sampled_subgraph(
    self,
    C_sampled_subgraph: torch.ScriptObject,
) -> SampledSubgraphImpl:
    """Convert a fused homogeneous sampled subgraph to 'SampledSubgraphImpl'.

    Parameters
    ----------
    C_sampled_subgraph : torch.ScriptObject
        Sampled subgraph object. The attributes ``indptr``, ``indices``,
        ``type_per_edge``, ``original_column_node_ids`` and
        ``original_edge_ids`` are read from it.

    Returns
    -------
    SampledSubgraphImpl
        Only ``node_pairs`` and ``original_edge_ids`` are populated. When
        ``type_per_edge`` is ``None``, ``node_pairs`` is a single
        ``CSCFormatBase``; otherwise it is a dict mapping every edge type in
        ``self.metadata.edge_type_to_id`` to its own ``CSCFormatBase`` whose
        indices are relative to the source node type's offset.
    """
    indptr = C_sampled_subgraph.indptr
    indices = C_sampled_subgraph.indices
    type_per_edge = C_sampled_subgraph.type_per_edge
    column = C_sampled_subgraph.original_column_node_ids
    original_edge_ids = C_sampled_subgraph.original_edge_ids

    has_original_eids = (
        self.edge_attributes is not None
        and ORIGINAL_EDGE_ID in self.edge_attributes
    )
    if has_original_eids:
        # Map the sampled edge ids through the stored original-edge-id
        # attribute.
        original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
            original_edge_ids
        ]

    if type_per_edge is None:
        # The sampled graph is already a homogeneous graph.
        return SampledSubgraphImpl(
            node_pairs=CSCFormatBase(indptr=indptr, indices=indices),
            original_edge_ids=original_edge_ids,
        )

    # The sampled graph is a fused homogenized graph, which needs to be
    # converted to heterogeneous graphs.

    # Pre-calculate the number of edges of each edge type in a single
    # vectorized pass instead of one Python iteration per edge.
    etype_ids, etype_counts = torch.unique(type_per_edge, return_counts=True)
    num = dict(zip(etype_ids.tolist(), etype_counts.tolist()))

    # Preallocate the per-edge-type outputs.
    subgraph_indice_position = {}
    subgraph_indice = {}
    subgraph_indptr = {}
    # Destination node type id -> [(etype, etype_id, source type offset)].
    node_edge_type = defaultdict(list)
    original_hetero_edge_ids = {}
    for etype, etype_id in self.metadata.edge_type_to_id.items():
        num_etype_edges = num.get(etype_id, 0)
        subgraph_indice[etype] = torch.empty(
            (num_etype_edges,), dtype=indices.dtype
        )
        if has_original_eids:
            original_hetero_edge_ids[etype] = torch.empty(
                (num_etype_edges,), dtype=original_edge_ids.dtype
            )
        subgraph_indptr[etype] = [0]
        subgraph_indice_position[etype] = 0
        # Group edge types by destination node type so that each seed only
        # visits the edge types that can point to it. The source offset
        # depends only on the edge type, so it is resolved once here rather
        # than once per seed.
        src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
        src_offset = self.node_type_offset[
            self.metadata.node_type_to_id[src_ntype]
        ]
        dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
        node_edge_type[dst_ntype_id].append((etype, etype_id, src_offset))

    # Construct the per-edge-type subgraphs, one seed (column) at a time.
    for i, seed in enumerate(column):
        start = indptr[i].item()
        stop = indptr[i + 1].item()
        # Locate the node type whose offset range contains this seed.
        node_type = (
            torch.searchsorted(
                self.node_type_offset, seed, right=True
            ).item()
            - 1
        )
        for etype, etype_id, src_offset in node_edge_type[node_type]:
            # NOTE(review): this presumably relies on the edge types inside
            # each seed's slice being sorted ascending, and on edge types
            # being visited in ascending id order -- confirm against the
            # sampler. `start` advances past each consumed run, so the
            # search result is the length of the run for this edge type.
            num_edges = torch.searchsorted(
                type_per_edge[start:stop], etype_id, right=True
            ).item()
            end = start + num_edges
            subgraph_indptr[etype].append(
                subgraph_indptr[etype][-1] + num_edges
            )
            offset = subgraph_indice_position[etype]
            subgraph_indice_position[etype] = offset + num_edges
            # Convert homogeneous node ids to ids local to the source type.
            subgraph_indice[etype][offset : offset + num_edges] = (
                indices[start:end] - src_offset
            )
            if has_original_eids:
                original_hetero_edge_ids[etype][
                    offset : offset + num_edges
                ] = original_edge_ids[start:end]
            start = end

    # NOTE(review): without ORIGINAL_EDGE_ID in the edge attributes,
    # `original_edge_ids` is passed through as read from the input object
    # and is not split per edge type -- confirm this is intended.
    if has_original_eids:
        original_edge_ids = original_hetero_edge_ids
    node_pairs = {
        etype: CSCFormatBase(
            indptr=torch.tensor(subgraph_indptr[etype]),
            indices=subgraph_indice[etype],
        )
        for etype in self.metadata.edge_type_to_id
    }
    return SampledSubgraphImpl(
        node_pairs=node_pairs,
        original_edge_ids=original_edge_ids,
    )
def sample_neighbors( def sample_neighbors(
self, self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl: deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -476,7 +576,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -476,7 +576,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name nodes, fanouts, replace, probs_name
) )
if deduplicate is True:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name): def _check_sampler_arguments(self, nodes, fanouts, probs_name):
...@@ -584,7 +686,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -584,7 +686,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl: deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs `Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
...@@ -667,6 +770,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -667,6 +770,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
probs_name, probs_name,
) )
if deduplicate:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def sample_negative_edges_uniform( def sample_negative_edges_uniform(
......
...@@ -115,6 +115,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -115,6 +115,7 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop], self.fanouts[hop],
self.replace, self.replace,
self.prob_name, self.prob_name,
self.deduplicate,
) )
if self.deduplicate: if self.deduplicate:
( (
......
...@@ -5,10 +5,10 @@ from typing import Dict, Tuple, Union ...@@ -5,10 +5,10 @@ from typing import Dict, Tuple, Union
import torch import torch
from ..base import etype_str_to_tuple from ..base import CSCFormatBase, etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph from ..sampled_subgraph import SampledSubgraph
__all__ = ["FusedSampledSubgraphImpl"] __all__ = ["FusedSampledSubgraphImpl", "SampledSubgraphImpl"]
@dataclass @dataclass
...@@ -67,3 +67,65 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -67,3 +67,65 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
assert all( assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs isinstance(item, torch.Tensor) for item in self.node_pairs
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
    r"""Sampled subgraph of CSCSamplingGraph.

    ``node_pairs`` is either a single ``CSCFormatBase`` or a dict mapping an
    edge type string of the form ``"str:str:str"`` to a ``CSCFormatBase``.
    The node pairs are validated on construction; see ``__post_init__``.

    Examples
    --------
    >>> node_pairs = {"A:relation:B": CSCFormatBase(
    ...     indptr=torch.tensor([0, 1, 2, 3]),
    ...     indices=torch.tensor([0, 1, 2]))}
    >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
    >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
    >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
    >>> subgraph = gb.SampledSubgraphImpl(
    ...     node_pairs=node_pairs,
    ...     original_column_node_ids=original_column_node_ids,
    ...     original_row_node_ids=original_row_node_ids,
    ...     original_edge_ids=original_edge_ids
    ... )
    >>> print(subgraph.node_pairs)
    {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), indices=tensor([0, 1, 2]))}
    >>> print(subgraph.original_column_node_ids)
    {'B': tensor([10, 11, 12])}
    >>> print(subgraph.original_row_node_ids)
    {'A': tensor([13, 14, 15])}
    >>> print(subgraph.original_edge_ids)
    {'A:relation:B': tensor([19, 20, 21])}
    """

    node_pairs: Union[
        CSCFormatBase,
        Dict[str, CSCFormatBase],
    ] = None
    original_column_node_ids: Union[
        Dict[str, torch.Tensor], torch.Tensor
    ] = None
    original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
    original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None

    def __post_init__(self):
        """Validate ``node_pairs`` right after the dataclass is initialized.

        Raises
        ------
        AssertionError
            If an edge type key is not a ``str:str:str`` string, or if a node
            pair is missing ``indptr``/``indices`` or holds non-tensor values.
        """
        if isinstance(self.node_pairs, dict):
            for etype, pair in self.node_pairs.items():
                assert (
                    isinstance(etype, str)
                    and len(etype_str_to_tuple(etype)) == 3
                ), "Edge type should be a string in format of str:str:str."
                self._check_node_pair(pair)
        else:
            self._check_node_pair(self.node_pairs)

    @staticmethod
    def _check_node_pair(pair):
        """Assert that ``pair`` carries tensor ``indptr`` and ``indices``."""
        # `getattr` with a default makes a missing node pair (e.g. the field
        # default `None`) fail with the assertion message below instead of a
        # bare AttributeError.
        indptr = getattr(pair, "indptr", None)
        indices = getattr(pair, "indices", None)
        assert (
            indptr is not None and indices is not None
        ), "Node pair should be have indptr and indice."
        assert isinstance(indptr, torch.Tensor) and isinstance(
            indices, torch.Tensor
        ), "Nodes in pairs should be of type torch.Tensor."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment