Unverified commit 3d2c37eb, authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

[Graphbolt] Support converting dgl node classification dataset to graphbolt. (#6698)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent c2134442
@@ -3,6 +3,7 @@ from .basic_feature_store import *
 from .fused_csc_sampling_graph import *
 from .gpu_cached_feature import *
 from .in_subgraph_sampler import *
+from .legacy_dataset import *
 from .neighbor_sampler import *
 from .ondisk_dataset import *
 from .ondisk_metadata import *
......
"""Graphbolt dataset for legacy DGLDataset."""
from typing import List, Union
from dgl.data import AsNodePredDataset, DGLDataset
from ..base import etype_tuple_to_str
from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .basic_feature_store import BasicFeatureStore
from .fused_csc_sampling_graph import from_dglgraph
from .ondisk_dataset import OnDiskTask
from .torch_based_feature_store import TorchBasedFeature
class LegacyDataset(Dataset):
    """A Graphbolt dataset for legacy DGLDataset.

    Converts a single-graph node classification ``DGLDataset`` into the
    pieces exposed by the properties of this class: one node
    classification task, a sampling graph, a feature store, the dataset
    name and an item set covering all nodes.

    Parameters
    ----------
    legacy : DGLDataset
        The legacy dataset to convert. It must contain exactly one item.
        Homogeneous graphs are wrapped with ``AsNodePredDataset``;
        heterogeneous graphs are only supported when the dataset provides
        ``get_idx_split`` (OGB datasets).

    Raises
    ------
    AssertionError
        If ``legacy`` does not contain exactly one item.
    NotImplementedError
        If the graph is heterogeneous and ``legacy`` has no
        ``get_idx_split`` attribute.
    """

    def __init__(self, legacy: DGLDataset):
        # Only supports single graph cases.
        assert (
            len(legacy) == 1
        ), "LegacyDataset only supports datasets with a single graph."
        graph = legacy[0]
        # Handle OGB Dataset, whose items are (graph, labels) tuples.
        if isinstance(graph, tuple):
            graph, _ = graph
        if graph.is_homogeneous:
            self._init_as_homogeneous_node_pred(legacy)
        else:
            self._init_as_heterogeneous_node_pred(legacy)

    @staticmethod
    def _to_feature(tensor) -> TorchBasedFeature:
        """Wrap ``tensor`` in a ``TorchBasedFeature``.

        1-D tensors are first viewed as a single column of shape
        ``(-1, 1)`` so that every stored feature has at least 2 dimensions.
        """
        if tensor.dim() == 1:
            tensor = tensor.view(-1, 1)
        return TorchBasedFeature(tensor)

    def _init_as_heterogeneous_node_pred(self, legacy: DGLDataset):
        """Initialize from a heterogeneous node prediction dataset.

        Only datasets exposing ``get_idx_split`` (OGB datasets) are
        supported; anything else raises ``NotImplementedError``.
        """
        # OGB Dataset has the idx split.
        if not hasattr(legacy, "get_idx_split"):
            raise NotImplementedError(
                "Only support heterogeneous ogb node pred dataset"
            )

        def _init_item_set_dict(idx, labels):
            # One (seed_nodes, labels) item set per key of the split.
            item_set_dict = {}
            for key in idx.keys():
                nodes = idx[key]
                item_set_dict[key] = ItemSet(
                    (nodes, labels[key][nodes]),
                    names=("seed_nodes", "labels"),
                )
            return ItemSetDict(item_set_dict)

        graph, labels = legacy[0]
        split_idx = legacy.get_idx_split()

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = _init_item_set_dict(split_idx["train"], labels)
        validation_set = _init_item_set_dict(split_idx["valid"], labels)
        test_set = _init_item_set_dict(split_idx["test"], labels)
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        # Item set covering every node of every node type.
        self._all_nodes_set = ItemSetDict(
            {
                ntype: ItemSet(graph.num_nodes(ntype), names="seed_nodes")
                for ntype in graph.ntypes
            }
        )

        # Collect node and edge features keyed by (domain, type, name).
        features = {}
        for ntype in graph.ntypes:
            node_data = graph.nodes[ntype].data
            for name in node_data.keys():
                features[("node", ntype, name)] = self._to_feature(
                    node_data[name]
                )
        for etype in graph.canonical_etypes:
            edge_data = graph.edges[etype].data
            # Edge features are keyed by the string form of the edge type.
            gb_etype = etype_tuple_to_str(etype)
            for name in edge_data.keys():
                features[("edge", gb_etype, name)] = self._to_feature(
                    edge_data[name]
                )
        self._feature = BasicFeatureStore(features)
        self._graph = from_dglgraph(graph, is_homogeneous=False)
        self._dataset_name = legacy.name

    def _init_as_homogeneous_node_pred(self, legacy: DGLDataset):
        """Initialize from a homogeneous node prediction dataset.

        The dataset is wrapped with ``AsNodePredDataset``, whose
        ``train_idx``, ``val_idx`` and ``test_idx`` define the splits and
        whose graph stores the labels in ``ndata["label"]``.
        """
        legacy = AsNodePredDataset(legacy)
        # Fetch the graph once instead of re-indexing the dataset each time.
        graph = legacy[0]
        labels = graph.ndata["label"]

        def _init_item_set(idx):
            return ItemSet((idx, labels[idx]), names=("seed_nodes", "labels"))

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = _init_item_set(legacy.train_idx)
        validation_set = _init_item_set(legacy.val_idx)
        test_set = _init_item_set(legacy.test_idx)
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        self._all_nodes_set = ItemSet(graph.num_nodes(), names="seed_nodes")

        # Homogeneous graphs have no node/edge type, hence the ``None`` key.
        features = {}
        for name in graph.ndata.keys():
            features[("node", None, name)] = self._to_feature(
                graph.ndata[name]
            )
        for name in graph.edata.keys():
            features[("edge", None, name)] = self._to_feature(
                graph.edata[name]
            )
        self._feature = BasicFeatureStore(features)
        self._graph = from_dglgraph(graph, is_homogeneous=True)
        self._dataset_name = legacy.name

    @property
    def tasks(self) -> List[Task]:
        """Return the tasks."""
        return self._tasks

    @property
    def graph(self) -> SamplingGraph:
        """Return the graph."""
        return self._graph

    @property
    def feature(self) -> BasicFeatureStore:
        """Return the feature."""
        return self._feature

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return self._dataset_name

    @property
    def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the itemset containing all nodes."""
        return self._all_nodes_set
import dgl.graphbolt as gb
import pytest
import torch
from dgl import AddSelfLoop
from dgl.data import AsNodePredDataset, CoraGraphDataset
def test_LegacyDataset_homo_node_pred():
    """Check converting the homogeneous Cora dataset to ``gb.LegacyDataset``.

    Verifies the single node classification task (split sizes, item names,
    number of classes), the node counts of the graph and of the all-nodes
    item set, and reading of the ``feat`` node feature.
    """
    cora = CoraGraphDataset(transform=AddSelfLoop())
    dataset = gb.LegacyDataset(cora)

    # Check tasks.
    assert len(dataset.tasks) == 1
    task = dataset.tasks[0]
    expected_names = ("seed_nodes", "labels")
    assert task.train_set.names == expected_names
    assert len(task.train_set) == 140
    assert task.validation_set.names == expected_names
    assert len(task.validation_set) == 500
    assert task.test_set.names == expected_names
    assert len(task.test_set) == 1000
    assert task.metadata["num_classes"] == 7

    # Check graph and the item set of all nodes.
    num_nodes = 2708
    assert dataset.graph.num_nodes == num_nodes
    assert len(dataset.all_nodes_set) == num_nodes

    # Check features.
    assert dataset.feature.size("node", None, "feat") == torch.Size([1433])
    # Build indices with ``torch.tensor`` so they are integers;
    # ``torch.Tensor([...])`` would create a float32 tensor instead.
    last_node = torch.tensor([num_nodes - 1])
    assert (
        dataset.feature.read("node", None, "feat", last_node).size(dim=0) == 1
    )
    # Reading one past the last node must fail because the index is out of
    # bounds, not because of the index dtype.
    out_of_bound = torch.tensor([num_nodes])
    with pytest.raises(IndexError):
        dataset.feature.read("node", None, "feat", out_of_bound)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment