Unverified commit f08d2a8a, authored by Mingbang Wang and committed by GitHub
Browse files

[GraphBolt] Add `all_nodes_set` to `OnDiskDataset` (#6434)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 0b00b581
"""GraphBolt Dataset."""
from typing import Dict, List
from typing import Dict, List, Union
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
from .sampling_graph import SamplingGraph
__all__ = [
"Task",
......@@ -73,7 +74,7 @@ class Dataset:
raise NotImplementedError
@property
def graph(self) -> object:
def graph(self) -> SamplingGraph:
"""Return the graph."""
raise NotImplementedError
......@@ -86,3 +87,8 @@ class Dataset:
def dataset_name(self) -> str:
"""Return the dataset name."""
raise NotImplementedError
@property
def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the itemset containing all nodes.

This base-class accessor is abstract: it unconditionally raises
``NotImplementedError`` and is meant to be overridden by concrete
datasets.

Returns
-------
Union[ItemSet, ItemSetDict]
    The itemset containing all nodes, as provided by an overriding
    implementation.

Raises
------
NotImplementedError
    Always, in this base implementation.
"""
raise NotImplementedError
......@@ -15,6 +15,7 @@ from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from ..utils import read_data, save_data
from .csc_sampling_graph import (
CSCSamplingGraph,
......@@ -389,6 +390,7 @@ class OnDiskDataset(Dataset):
self._graph = self._load_graph(self._meta.graph_topology)
self._feature = TorchBasedFeatureStore(self._meta.feature_data)
self._tasks = self._init_tasks(self._meta.tasks)
self._all_nodes_set = self._init_all_nodes_set(self._graph)
return self
@property
......@@ -402,7 +404,7 @@ class OnDiskDataset(Dataset):
return self._tasks
@property
def graph(self) -> object:
def graph(self) -> SamplingGraph:
"""Return the graph."""
return self._graph
......@@ -416,6 +418,11 @@ class OnDiskDataset(Dataset):
"""Return the dataset name."""
return self._dataset_name
@property
def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the itemset containing all nodes.

The value is the cached ``self._all_nodes_set`` attribute, which is
assigned from ``self._init_all_nodes_set(self._graph)`` when the dataset
is loaded.

Returns
-------
Union[ItemSet, ItemSetDict]
    Whatever ``_init_all_nodes_set`` produced for the loaded graph.
"""
return self._all_nodes_set
def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
"""Initialize the tasks."""
ret = []
......@@ -475,6 +482,19 @@ class OnDiskDataset(Dataset):
ret = ItemSetDict(data)
return ret
def _init_all_nodes_set(self, graph) -> Union[ItemSet, ItemSetDict]:
    """Build the itemset that covers every node of ``graph``.

    Parameters
    ----------
    graph
        The loaded graph, or ``None`` when the dataset has no graph. Only
        its ``num_nodes`` attribute is read.

    Returns
    -------
    Union[ItemSet, ItemSetDict]
        ``None`` when ``graph`` is ``None``. An ``ItemSet`` built from the
        node count when ``graph.num_nodes`` is an ``int``. Otherwise an
        ``ItemSetDict`` holding one ``ItemSet`` per ``(node_type, count)``
        entry of ``graph.num_nodes``.
    """
    # Nothing to enumerate when no graph was loaded.
    if graph is None:
        return None

    node_counts = graph.num_nodes

    # A non-int count is treated as a mapping from node type to count.
    if not isinstance(node_counts, int):
        per_type_sets = {}
        for ntype, count in node_counts.items():
            per_type_sets[ntype] = ItemSet(count)
        return ItemSetDict(per_type_sets)

    # A plain int count yields a single itemset over all nodes.
    return ItemSet(node_counts)
class BuiltinDataset(OnDiskDataset):
"""A utility class to download built-in dataset from AWS S3 and load it as
......
......@@ -1711,6 +1711,71 @@ def test_OnDiskDataset_load_tasks():
dataset = None
def test_OnDiskDataset_all_nodes_set_homo():
"""Test homograph's all nodes set of OnDiskDataset.

Saves a random homogeneous graph to a temporary directory, loads it
through ``OnDiskDataset`` and checks that ``all_nodes_set`` is an
``ItemSet`` whose i-th item equals ``i``.
"""
# Random homogeneous graph; presumably 1000 nodes and 10 * 1000 edges
# (argument meaning not visible here -- confirm against the helper).
csc_indptr, indices = gbt.random_homo_graph(1000, 10 * 1000)
graph = gb.from_csc(csc_indptr, indices)
with tempfile.TemporaryDirectory() as test_dir:
# Persist the graph so the dataset can reference it by path.
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
# Metadata declares only the graph topology: no features, no tasks.
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
path: {graph_path}
"""
# The metadata file is written to <test_dir>/preprocessed/metadata.yaml.
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set
# A graph built without node types must produce a plain ItemSet.
assert isinstance(all_nodes_set, gb.ItemSet)
# Items are expected to come out as consecutive IDs starting at 0.
for i, item in enumerate(all_nodes_set):
assert i == item
# Drop the reference to the loaded dataset.
dataset = None
def test_OnDiskDataset_all_nodes_set_hetero():
"""Test heterograph's all nodes set of OnDiskDataset.

Saves a random heterogeneous graph to a temporary directory, loads it
through ``OnDiskDataset`` and checks that ``all_nodes_set`` is an
``ItemSetDict`` whose items are single-entry dicts.
"""
# Random heterogeneous graph; presumably 1000 nodes, 10 * 1000 edges,
# 3 node types and 4 edge types (argument meaning not visible here --
# confirm against the helper).
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
)
with tempfile.TemporaryDirectory() as test_dir:
# Persist the graph so the dataset can reference it by path.
graph_path = os.path.join(test_dir, "csc_sampling_graph.tar")
gb.save_csc_sampling_graph(graph, graph_path)
# Metadata declares only the graph topology: no features, no tasks.
yaml_content = f"""
graph_topology:
type: CSCSamplingGraph
path: {graph_path}
"""
# The metadata file is written to <test_dir>/preprocessed/metadata.yaml.
os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
dataset = gb.OnDiskDataset(test_dir).load()
all_nodes_set = dataset.all_nodes_set
# A graph built with node-type metadata must produce an ItemSetDict.
assert isinstance(all_nodes_set, gb.ItemSetDict)
# Each yielded item must be a dict with exactly one entry
# (presumably {node_type: node_id} -- confirm against ItemSetDict).
for i, item in enumerate(all_nodes_set):
assert len(item) == 1
assert isinstance(item, dict)
# Drop the reference to the loaded dataset.
dataset = None
def test_BuiltinDataset():
"""Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment