Unverified Commit 1785acff authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Refactor csc format sampled subgraph. (#6553)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent fdeda8a8
"""Base types and utilities for Graph Bolt."""
from dataclasses import dataclass
import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
......@@ -13,6 +15,7 @@ __all__ = [
"etype_tuple_to_str",
"CopyTo",
"isin",
"CSCFormatBase",
]
CANONICAL_ETYPE_DELIMITER = ":"
......@@ -111,3 +114,21 @@ class CopyTo(IterDataPipe):
for data in self.datapipe:
data = recursive_apply(data, apply_to, self.device)
yield data
@dataclass
class CSCFormatBase:
r"""Basic class representing data in Compressed Sparse Column (CSC) format.
Examples
--------
>>> indptr = torch.tensor([0, 1, 3])
>>> indices = torch.tensor([1, 4, 2])
>>> csc_foramt_base = CSCFormatBase(indptr=indptr, indices=indices)
>>> print(csc_format_base.indptr)
... torch.tensor([0, 1, 3])
>>> print(csc_foramt_base)
... torch.tensor([1, 4, 2])
"""
indptr: torch.Tensor = None
indices: torch.Tensor = None
......@@ -15,7 +15,11 @@ from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
from ..sampling_graph import SamplingGraph
from .sampled_subgraph_impl import FusedSampledSubgraphImpl
from .sampled_subgraph_impl import (
CSCFormatBase,
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
__all__ = [
......@@ -342,9 +346,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "Nodes cannot have duplicate values."
_in_subgraph = self._c_csc_graph.in_subgraph(nodes)
return self._convert_to_sampled_subgraph(_in_subgraph)
return self._convert_to_fused_sampled_subgraph(_in_subgraph)
def _convert_to_sampled_subgraph(
def _convert_to_fused_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
):
......@@ -400,13 +404,109 @@ class FusedCSCSamplingGraph(SamplingGraph):
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes)
def _convert_to_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
) -> SampledSubgraphImpl:
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'."""
indptr = C_sampled_subgraph.indptr
indices = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids
original_edge_ids = C_sampled_subgraph.original_edge_ids
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
if has_original_eids:
original_edge_ids = self.edge_attributes[ORIGINAL_EDGE_ID][
original_edge_ids
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
else:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
# Pre-calculate the number of each etype
num = {}
for etype in type_per_edge:
num[etype.item()] = num.get(etype.item(), 0) + 1
# Preallocate
subgraph_indice_position = {}
subgraph_indice = {}
subgraph_indptr = {}
node_edge_type = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype
)
if has_original_eids:
original_hetero_edge_ids[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=original_edge_ids.dtype
)
subgraph_indptr[etype] = [0]
subgraph_indice_position[etype] = 0
# Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs
for (i, seed) in enumerate(column):
l = indptr[i].item()
r = indptr[i + 1].item()
node_type = (
torch.searchsorted(
self.node_type_offset, seed, right=True
).item()
- 1
)
for (etype, etype_id) in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
num_edges = torch.searchsorted(
type_per_edge[l:r], etype_id, right=True
).item()
end = num_edges + l
subgraph_indptr[etype].append(
subgraph_indptr[etype][-1] + num_edges
)
offset = subgraph_indice_position[etype]
subgraph_indice_position[etype] += num_edges
subgraph_indice[etype][offset : offset + num_edges] = (
indices[l:end] - self.node_type_offset[src_ntype_id]
)
if has_original_eids:
original_hetero_edge_ids[etype][
offset : offset + num_edges
] = original_edge_ids[l:end]
l = end
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
node_pairs = {
etype: CSCFormatBase(
indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype],
)
for etype in self.metadata.edge_type_to_id.keys()
}
return SampledSubgraphImpl(
node_pairs=node_pairs,
original_edge_ids=original_edge_ids,
)
def sample_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl:
deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -476,8 +576,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, probs_name
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
if deduplicate is True:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......@@ -584,7 +686,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
) -> FusedSampledSubgraphImpl:
deduplicate=True,
) -> Union[FusedSampledSubgraphImpl, SampledSubgraphImpl]:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
......@@ -667,7 +770,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
probs_name,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
if deduplicate:
return self._convert_to_fused_sampled_subgraph(C_sampled_subgraph)
else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio
......
......@@ -115,6 +115,7 @@ class NeighborSampler(SubgraphSampler):
self.fanouts[hop],
self.replace,
self.prob_name,
self.deduplicate,
)
if self.deduplicate:
(
......
......@@ -5,10 +5,10 @@ from typing import Dict, Tuple, Union
import torch
from ..base import etype_str_to_tuple
from ..base import CSCFormatBase, etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph
__all__ = ["FusedSampledSubgraphImpl"]
__all__ = ["FusedSampledSubgraphImpl", "SampledSubgraphImpl"]
@dataclass
......@@ -67,3 +67,65 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs
), "Nodes in pairs should be of type torch.Tensor."
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of CSCSamplingGraph.
Examples
--------
>>> node_pairs = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.node_pairs)
{"A:relation:B": CSCForamtBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs: Union[
CSCFormatBase,
Dict[str, CSCFormatBase],
] = None
original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.node_pairs, dict):
for etype, pair in self.node_pairs.items():
assert (
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
pair.indptr is not None and pair.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(pair.indptr, torch.Tensor) and isinstance(
pair.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
else:
assert (
self.node_pairs.indptr is not None
and self.node_pairs.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(
self.node_pairs.indptr, torch.Tensor
) and isinstance(
self.node_pairs.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment