"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f86212edb5012f36f30ed8f79513b2b3f54cf1ea"
Unverified Commit 25517e8f authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add sampled sub graph (#5964)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 9d63f3ea
...@@ -14,6 +14,7 @@ from .dataset import * ...@@ -14,6 +14,7 @@ from .dataset import *
from .impl import * from .impl import *
from .dataloader import * from .dataloader import *
from .subgraph_sampler import * from .subgraph_sampler import *
from .sampled_subgraph import *
def load_graphbolt(): def load_graphbolt():
......
...@@ -3,3 +3,4 @@ from .ondisk_dataset import * ...@@ -3,3 +3,4 @@ from .ondisk_dataset import *
from .ondisk_metadata import * from .ondisk_metadata import *
from .torch_based_feature_store import * from .torch_based_feature_store import *
from .csc_sampling_graph import * from .csc_sampling_graph import *
from .sampled_subgraph_impl import *
...@@ -196,6 +196,7 @@ class CSCSamplingGraph: ...@@ -196,6 +196,7 @@ class CSCSamplingGraph:
assert len(torch.unique(nodes)) == len( assert len(torch.unique(nodes)) == len(
nodes nodes
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
# TODO: change the result to 'SampledSubgraphImpl'.
return self._c_csc_graph.in_subgraph(nodes) return self._c_csc_graph.in_subgraph(nodes)
def sample_neighbors( def sample_neighbors(
......
"""Sampled subgraph for CSCSamplingGraph."""
# pylint: disable= invalid-name
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
from ..sampled_subgraph import SampledSubgraph
@dataclass
class SampledSubgraphImpl(SampledSubgraph):
    r"""Class for sampled subgraph specific for CSCSamplingGraph.

    Each field is either a plain value (homogeneous graph) or a dictionary
    keyed by node type / edge type (heterogeneous graph).

    Attributes
    ----------
    node_pairs
        A ``(u, v)`` tuple of tensors, or a dictionary mapping an edge type
        triplet ``(str, str, str)`` to such a tuple. Validated in
        ``__post_init__``.
    reverse_column_node_ids
        A tensor, a dictionary keyed by node type, or ``None``.
    reverse_row_node_ids
        A tensor, a dictionary keyed by node type, or ``None``.
    reverse_edge_ids
        A tensor, a dictionary keyed by edge type triplet, or ``None``.

    Examples
    --------
    >>> node_pairs = {('A', 'B', 'relation'): (torch.tensor([1, 2, 3]),
    ...     torch.tensor([4, 5, 6]))}
    >>> reverse_column_node_ids = {'A': torch.tensor([7, 8, 9]),
    ...     'B': torch.tensor([10, 11, 12])}
    >>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15]),
    ...     'B': torch.tensor([16, 17, 18])}
    >>> reverse_edge_ids = {('A', 'B', 'relation'): torch.tensor([19, 20, 21])}
    >>> subgraph = gb.SampledSubgraphImpl(
    ...     node_pairs=node_pairs,
    ...     reverse_column_node_ids=reverse_column_node_ids,
    ...     reverse_row_node_ids=reverse_row_node_ids,
    ...     reverse_edge_ids=reverse_edge_ids
    ... )
    >>> print(subgraph.node_pairs)
    {('A', 'B', 'relation'): (tensor([1, 2, 3]), tensor([4, 5, 6]))}
    >>> print(subgraph.reverse_column_node_ids)
    {'A': tensor([7, 8, 9]), 'B': tensor([10, 11, 12])}
    >>> print(subgraph.reverse_row_node_ids)
    {'A': tensor([13, 14, 15]), 'B': tensor([16, 17, 18])}
    >>> print(subgraph.reverse_edge_ids)
    {('A', 'B', 'relation'): tensor([19, 20, 21])}
    """
    # NOTE: annotations use the `torch.Tensor` class; `torch.tensor` is the
    # factory function and is not a type.
    node_pairs: Union[
        Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor],
    ] = None
    reverse_column_node_ids: Union[
        Dict[str, torch.Tensor], torch.Tensor, None
    ] = None
    reverse_row_node_ids: Union[
        Dict[str, torch.Tensor], torch.Tensor, None
    ] = None
    reverse_edge_ids: Union[
        Dict[Tuple[str, str, str], torch.Tensor], torch.Tensor, None
    ] = None

    def __post_init__(self):
        """Validate the structure of ``node_pairs`` after construction.

        Raises
        ------
        AssertionError
            If ``node_pairs`` is neither a 2-tuple of tensors nor a
            dictionary mapping string triplets to 2-tuples of tensors.
            This includes the default ``None``.
        """
        # Assertions are kept (rather than explicit raises) so that callers
        # relying on AssertionError keep working; note they are skipped when
        # Python runs with -O.
        if isinstance(self.node_pairs, dict):
            for etype, pair in self.node_pairs.items():
                assert (
                    isinstance(etype, tuple) and len(etype) == 3
                ), "Edge type should be a triplet of strings (str, str, str)."
                assert all(
                    isinstance(item, str) for item in etype
                ), "Edge type should be a triplet of strings (str, str, str)."
                self._check_node_pair(pair)
        else:
            self._check_node_pair(self.node_pairs)

    @staticmethod
    def _check_node_pair(pair):
        """Assert that ``pair`` is a 2-tuple whose items are tensors."""
        assert (
            isinstance(pair, tuple) and len(pair) == 2
        ), "Node pair should be a source-destination tuple (u, v)."
        assert all(
            isinstance(item, torch.Tensor) for item in pair
        ), "Nodes in pairs should be of type torch.Tensor."
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
from typing import Dict, Tuple, Union

import torch
class SampledSubgraph:
    r"""An abstract class for sampled subgraph. In the context of a
    heterogeneous graph, each field should be of `Dict` type. Otherwise,
    for homogeneous graphs, each field should correspond to its respective
    value type.

    Subclasses must provide `node_pairs`; the `reverse_*` fields are
    optional and default to ``None`` here.
    """

    # NOTE: the return annotations use `Union[...]`. Writing `A or B` in an
    # annotation evaluates to `A` alone at runtime and silently drops `B`.

    @property
    def node_pairs(
        self,
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor],
        Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Returns the node pairs representing source-destination edges.

        - If `node_pairs` is a tuple: It should be in the format ('u', 'v')
          representing source and destination pairs.
        - If `node_pairs` is a dictionary: The keys should be edge types and
          the values should be the corresponding node pairs. The ids inside
          are heterogeneous ids.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError

    @property
    def reverse_column_node_ids(
        self,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor], None]:
        """Returns the corresponding reverse column node ids in the original
        graph.

        Column's reverse node ids in the original graph. A graph structure
        can be treated as a coordinated row and column pair, and this is
        the mapped ids of the column.

        - If `reverse_column_node_ids` is a tensor: It represents the
          original node ids.
        - If `reverse_column_node_ids` is a dictionary: The keys should be
          node types and the values should be the corresponding original
          heterogeneous node ids.

        If present, it means column IDs are compacted, and `node_pairs`
        column IDs match these compacted ones. The base implementation
        returns ``None``.
        """
        return None

    @property
    def reverse_row_node_ids(
        self,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor], None]:
        """Returns the corresponding reverse row node ids in the original
        graph.

        Row's reverse node ids in the original graph. A graph structure
        can be treated as a coordinated row and column pair, and this is
        the mapped ids of the row.

        - If `reverse_row_node_ids` is a tensor: It represents the
          original node ids.
        - If `reverse_row_node_ids` is a dictionary: The keys should be
          node types and the values should be the corresponding original
          heterogeneous node ids.

        If present, it means row IDs are compacted, and `node_pairs`
        row IDs match these compacted ones. The base implementation
        returns ``None``.
        """
        return None

    @property
    def reverse_edge_ids(
        self,
    ) -> Union[torch.Tensor, Dict[Tuple[str, str, str], torch.Tensor], None]:
        """Returns the corresponding reverse edge ids in the original graph.

        Reverse edge ids in the original graph. This is useful when edge
        features are needed.

        - If `reverse_edge_ids` is a tensor: It represents the
          original edge ids.
        - If `reverse_edge_ids` is a dictionary: The keys should be
          edge types and the values should be the corresponding original
          heterogeneous edge ids.

        The base implementation returns ``None``.
        """
        return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment