Unverified Commit 51ff8255 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove `FusedSampledSubgraphImpl`. (#6858)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent af990f37
"""Sampled subgraph for FusedCSCSamplingGraph.""" """Sampled subgraph for FusedCSCSamplingGraph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Tuple, Union from typing import Dict, Union
import torch import torch
...@@ -9,69 +9,7 @@ from ..base import CSCFormatBase, etype_str_to_tuple ...@@ -9,69 +9,7 @@ from ..base import CSCFormatBase, etype_str_to_tuple
from ..internal import get_attributes from ..internal import get_attributes
from ..sampled_subgraph import SampledSubgraph from ..sampled_subgraph import SampledSubgraph
__all__ = ["FusedSampledSubgraphImpl", "SampledSubgraphImpl"] __all__ = ["SampledSubgraphImpl"]
@dataclass
class FusedSampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples
--------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl(
... sampled_csc=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.sampled_csc)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
sampled_csc: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
] = None
original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.sampled_csc, dict):
for etype, pair in self.sampled_csc.items():
assert (
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
isinstance(pair, tuple) and len(pair) == 2
), "Node pair should be a source-destination tuple (u, v)."
assert all(
isinstance(item, torch.Tensor) for item in pair
), "Nodes in pairs should be of type torch.Tensor."
else:
assert (
isinstance(self.sampled_csc, tuple)
and len(self.sampled_csc) == 2
), "Node pair should be a source-destination tuple (u, v)."
assert all(
isinstance(item, torch.Tensor) for item in self.sampled_csc
), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
return _sampled_subgraph_str(self, "FusedSampledSubgraphImpl")
@dataclass @dataclass
......
...@@ -5,10 +5,7 @@ import backend as F ...@@ -5,10 +5,7 @@ import backend as F
import dgl.graphbolt as gb import dgl.graphbolt as gb
import pytest import pytest
import torch import torch
from dgl.graphbolt.impl.sampled_subgraph_impl import ( from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
def _assert_container_equal(lhs, rhs): def _assert_container_equal(lhs, rhs):
......
...@@ -80,12 +80,15 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -80,12 +80,15 @@ def test_FeatureFetcher_with_edges_homo():
def add_node_and_edge_ids(seeds): def add_node_and_edge_ids(seeds):
subgraphs = [] subgraphs = []
for _ in range(3): for _ in range(3):
range_tensor = torch.arange(10) sampled_csc = gb.CSCFormatBase(
indptr=torch.arange(11),
indices=torch.arange(10),
)
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.SampledSubgraphImpl(
sampled_csc=(range_tensor, range_tensor), sampled_csc=sampled_csc,
original_column_node_ids=range_tensor, original_column_node_ids=torch.arange(10),
original_row_node_ids=range_tensor, original_row_node_ids=torch.arange(10),
original_edge_ids=torch.randint( original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,) 0, graph.total_num_edges, (10,)
), ),
...@@ -183,15 +186,15 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -183,15 +186,15 @@ def test_FeatureFetcher_with_edges_hetero():
} }
for _ in range(3): for _ in range(3):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.SampledSubgraphImpl(
sampled_csc={ sampled_csc={
"n1:e1:n2": ( "n1:e1:n2": gb.CSCFormatBase(
torch.arange(10), indptr=torch.arange(11),
torch.arange(10), indices=torch.arange(10),
), ),
"n2:e2:n1": ( "n2:e2:n1": gb.CSCFormatBase(
torch.arange(10), indptr=torch.arange(11),
torch.arange(10), indices=torch.arange(10),
), ),
}, },
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment