Unverified Commit 51ff8255 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove `FusedSampledSubgraphImpl`. (#6858)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent af990f37
"""Sampled subgraph for FusedCSCSamplingGraph."""
# pylint: disable= invalid-name
from dataclasses import dataclass
from typing import Dict, Tuple, Union
from typing import Dict, Union
import torch
......@@ -9,69 +9,7 @@ from ..base import CSCFormatBase, etype_str_to_tuple
from ..internal import get_attributes
from ..sampled_subgraph import SampledSubgraph
__all__ = ["FusedSampledSubgraphImpl", "SampledSubgraphImpl"]
@dataclass
class FusedSampledSubgraphImpl(SampledSubgraph):
r"""Sampled subgraph of FusedCSCSamplingGraph.
Examples
--------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl(
... sampled_csc=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.sampled_csc)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.original_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
sampled_csc: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
] = None
original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.sampled_csc, dict):
for etype, pair in self.sampled_csc.items():
assert (
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
isinstance(pair, tuple) and len(pair) == 2
), "Node pair should be a source-destination tuple (u, v)."
assert all(
isinstance(item, torch.Tensor) for item in pair
), "Nodes in pairs should be of type torch.Tensor."
else:
assert (
isinstance(self.sampled_csc, tuple)
and len(self.sampled_csc) == 2
), "Node pair should be a source-destination tuple (u, v)."
assert all(
isinstance(item, torch.Tensor) for item in self.sampled_csc
), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
return _sampled_subgraph_str(self, "FusedSampledSubgraphImpl")
__all__ = ["SampledSubgraphImpl"]
@dataclass
......
......@@ -5,10 +5,7 @@ import backend as F
import dgl.graphbolt as gb
import pytest
import torch
from dgl.graphbolt.impl.sampled_subgraph_impl import (
FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl
def _assert_container_equal(lhs, rhs):
......
......@@ -80,12 +80,15 @@ def test_FeatureFetcher_with_edges_homo():
def add_node_and_edge_ids(seeds):
subgraphs = []
for _ in range(3):
range_tensor = torch.arange(10)
sampled_csc = gb.CSCFormatBase(
indptr=torch.arange(11),
indices=torch.arange(10),
)
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=(range_tensor, range_tensor),
original_column_node_ids=range_tensor,
original_row_node_ids=range_tensor,
gb.SampledSubgraphImpl(
sampled_csc=sampled_csc,
original_column_node_ids=torch.arange(10),
original_row_node_ids=torch.arange(10),
original_edge_ids=torch.randint(
0, graph.total_num_edges, (10,)
),
......@@ -183,15 +186,15 @@ def test_FeatureFetcher_with_edges_hetero():
}
for _ in range(3):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
gb.SampledSubgraphImpl(
sampled_csc={
"n1:e1:n2": (
torch.arange(10),
torch.arange(10),
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.arange(11),
indices=torch.arange(10),
),
"n2:e2:n1": (
torch.arange(10),
torch.arange(10),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.arange(11),
indices=torch.arange(10),
),
},
original_column_node_ids=original_column_node_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment