Unverified commit 6b047e4d authored by Rhett Ying, committed by GitHub

[GraphBolt] wrap tvt with task structure (#6112)

parent a2234d60
"""GraphBolt Dataset."""
from typing import Dict, List
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
__all__ = ["Dataset"]
__all__ = [
"Task",
"Dataset",
]
-class Dataset:
-    """An abstract dataset.
+class Task:
+    """An abstract task.

-    Dataset provides abstraction for accessing the data required for training.
-    The data abstraction could be a native CPU memory block, a shared memory
-    block, a file handle of an opened file on disk, a service that provides
-    the API to access the data, etc. There are 3 primary components in the
-    dataset: *Train-Validation-Test Set*, *Feature Storage*, *Graph Topology*.
+    A task consists of meta information and a *Train-Validation-Test Set*.
+
+    *Meta information*:
+    The meta information of a task includes any kind of data that the user
+    defines in YAML when instantiating the task.

    *Train-Validation-Test Set*:
    The training-validation-testing (TVT) set which is used to train the neural
    networks. We calculate the embeddings based on their respective features
    and the graph structure, and then utilize the embeddings to optimize the
    neural network parameters.

-    *Feature Storage*:
-    A key-value store which stores node/edge/graph features.
-
-    *Graph Topology*:
-    Graph topology is used by the subgraph sampling algorithm to
-    generate a subgraph.
    """

    @property
    def metadata(self) -> Dict:
        """Return the task metadata."""
        raise NotImplementedError

    @property
    def train_set(self) -> ItemSet or ItemSetDict:
        """Return the training set."""

@@ -44,6 +46,33 @@ class Dataset:
        """Return the test set."""
        raise NotImplementedError
class Dataset:
    """An abstract dataset.

    Dataset provides abstraction for accessing the data required for training.
    The data abstraction could be a native CPU memory block, a shared memory
    block, a file handle of an opened file on disk, a service that provides
    the API to access the data, etc. There are 3 primary components in the
    dataset: *Task*, *Feature Storage*, *Graph Topology*.

    *Task*:
    A task consists of meta information and a *Train-Validation-Test Set*.
    A dataset could have multiple tasks.

    *Feature Storage*:
    A key-value store which stores node/edge/graph features.

    *Graph Topology*:
    Graph topology is used by the subgraph sampling algorithm to
    generate a subgraph.
    """

    @property
    def tasks(self) -> List[Task]:
        """Return the tasks."""
        raise NotImplementedError

    @property
    def graph(self) -> object:
        """Return the graph."""

@@ -58,13 +87,3 @@ class Dataset:
    def dataset_name(self) -> str:
        """Return the dataset name."""
        raise NotImplementedError
-
-    @property
-    def num_classes(self) -> int:
-        """Return the number of classes."""
-        raise NotImplementedError
-
-    @property
-    def num_labels(self) -> int:
-        """Return the number of labels."""
-        raise NotImplementedError
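With this refactor, consumers reach the train/validation/test sets through a task rather than through the dataset itself. A minimal sketch of the new access pattern, assuming `dataset` is any concrete `Dataset` implementation (such as the `OnDiskDataset` below):

# Each task carries its own metadata dict and its own TVT sets.
for task in dataset.tasks:
    print(task.metadata)  # e.g. {"name": "node_classification", "num_classes": 10}
    for item in task.train_set:  # ItemSet or ItemSetDict
        ...  # feed items into a sampler/dataloader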
@@ -4,7 +4,7 @@ import os
import shutil
from copy import deepcopy

-from typing import List
+from typing import Dict, List

import pandas as pd
import torch

@@ -12,7 +12,7 @@ import yaml
import dgl

-from ..dataset import Dataset
+from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..utils import read_data, save_data

@@ -22,7 +22,12 @@ from .csc_sampling_graph import (
    load_csc_sampling_graph,
    save_csc_sampling_graph,
)
-from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
+from .ondisk_metadata import (
+    OnDiskGraphTopology,
+    OnDiskMetaData,
+    OnDiskTaskData,
+    OnDiskTVTSet,
+)
from .torch_based_feature_store import TorchBasedFeatureStore

__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
@@ -178,12 +183,16 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
                feature["in_memory"],
            )

-    # 7. Save the train/val/test split according to the output_config.
+    # 7. Save tasks and train/val/test split according to the output_config.
+    if input_config.get("tasks", None):
+        for input_task, output_task in zip(
+            input_config["tasks"], output_config["tasks"]
+        ):
            for set_name in ["train_set", "validation_set", "test_set"]:
-               if set_name not in input_config:
+               if set_name not in input_task:
                    continue
                for input_set_per_type, output_set_per_type in zip(
-                   input_config[set_name], output_config[set_name]
+                   input_task[set_name], output_task[set_name]
                ):
                    for input_data, output_data in zip(
                        input_set_per_type["data"], output_set_per_type["data"]

@@ -211,6 +220,59 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
    return output_config_path
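For orientation, the loop above walks a parsed YAML config shaped like the examples later in this diff; a hypothetical minimal `input_config` (the path is illustrative):

# Parsed form of a metadata.yaml declaring one task. "name" and
# "num_classes" are free-form task fields; "train_set" follows the
# TVT-set schema ("validation_set"/"test_set" may be omitted).
input_config = {
    "tasks": [
        {
            "name": "node_classification",
            "num_classes": 10,
            "train_set": [
                {
                    "type": None,
                    "data": [{"format": "numpy", "path": "set/train.npy"}],
                }
            ],
        }
    ]
}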
class OnDiskTask:
    """An on-disk task.

    An on-disk task is for ``OnDiskDataset``. It contains the metadata and the
    train/val/test sets.
    """

    def __init__(
        self,
        metadata: Dict,
        train_set: ItemSet or ItemSetDict,
        validation_set: ItemSet or ItemSetDict,
        test_set: ItemSet or ItemSetDict,
    ):
        """Initialize a task.

        Parameters
        ----------
        metadata : Dict
            Metadata.
        train_set : ItemSet or ItemSetDict
            Training set.
        validation_set : ItemSet or ItemSetDict
            Validation set.
        test_set : ItemSet or ItemSetDict
            Test set.
        """
        self._metadata = metadata
        self._train_set = train_set
        self._validation_set = validation_set
        self._test_set = test_set

    @property
    def metadata(self) -> Dict:
        """Return the task metadata."""
        return self._metadata

    @property
    def train_set(self) -> ItemSet or ItemSetDict:
        """Return the training set."""
        return self._train_set

    @property
    def validation_set(self) -> ItemSet or ItemSetDict:
        """Return the validation set."""
        return self._validation_set

    @property
    def test_set(self) -> ItemSet or ItemSetDict:
        """Return the test set."""
        return self._test_set
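For illustration only, an `OnDiskTask` can also be built by hand (normally `OnDiskDataset._init_tasks` constructs it from the parsed YAML); the item values below are made up:

import torch
from dgl import graphbolt as gb

# Hypothetical hand-built task: ten node ids with binary labels.
ids = torch.arange(0, 10)
labels = torch.randint(0, 2, (10,))
task = OnDiskTask(
    metadata={"name": "node_classification", "num_classes": 2},
    train_set=gb.ItemSet((ids, labels)),
    validation_set=None,
    test_set=None,
)
assert task.metadata["num_classes"] == 2
assert len(task.train_set) == 10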
class OnDiskDataset(Dataset):
    """An on-disk dataset.

@@ -225,8 +287,6 @@ class OnDiskDataset(Dataset):

    .. code-block:: yaml

        dataset_name: graphbolt_test
-       num_classes: 10
-       num_labels: 10
        graph_topology:
          type: CSCSamplingGraph
          path: graph_topology/csc_sampling_graph.tar

@@ -243,6 +303,9 @@ class OnDiskDataset(Dataset):
              format: numpy
              in_memory: false
              path: edge_data/author-writes-paper-feat.npy
+       tasks:
+         - name: "edge_classification"
+           num_classes: 10
            train_set:
              - type: paper # could be null for homogeneous graph.
                data: # multiple data sources could be specified.
@@ -279,28 +342,14 @@ class OnDiskDataset(Dataset):
            yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
        self._meta = OnDiskMetaData(**yaml_data)
        self._dataset_name = self._meta.dataset_name
-       self._num_classes = self._meta.num_classes
-       self._num_labels = self._meta.num_labels
        self._graph = self._load_graph(self._meta.graph_topology)
        self._feature = TorchBasedFeatureStore(self._meta.feature_data)
-       self._train_set = self._init_tvt_set(self._meta.train_set)
-       self._validation_set = self._init_tvt_set(self._meta.validation_set)
-       self._test_set = self._init_tvt_set(self._meta.test_set)
+       self._tasks = self._init_tasks(self._meta.tasks)

    @property
-   def train_set(self) -> ItemSet or ItemSetDict:
-       """Return the training set."""
-       return self._train_set
-
-   @property
-   def validation_set(self) -> ItemSet or ItemSetDict:
-       """Return the validation set."""
-       return self._validation_set
-
-   @property
-   def test_set(self) -> ItemSet or ItemSetDict:
-       """Return the test set."""
-       return self._test_set
+   def tasks(self) -> List[Task]:
+       """Return the tasks."""
+       return self._tasks
    @property
    def graph(self) -> object:

@@ -317,15 +366,21 @@ class OnDiskDataset(Dataset):
        """Return the dataset name."""
        return self._dataset_name

-   @property
-   def num_classes(self) -> int:
-       """Return the number of classes."""
-       return self._num_classes
-
-   @property
-   def num_labels(self) -> int:
-       """Return the number of labels."""
-       return self._num_labels
+   def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
+       """Initialize the tasks."""
+       ret = []
+       if tasks is None:
+           return ret
+       for task in tasks:
+           ret.append(
+               OnDiskTask(
+                   task.extra_fields,
+                   self._init_tvt_set(task.train_set),
+                   self._init_tvt_set(task.validation_set),
+                   self._init_tvt_set(task.test_set),
+               )
+           )
+       return ret

    def _load_graph(
        self, graph_topology: OnDiskGraphTopology
    ...
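End to end, loading a preprocessed dataset and reading its first task then looks as follows (a sketch; the directory path is a placeholder and must contain a `metadata.yaml` in the format shown above):

from dgl import graphbolt as gb

dataset = gb.OnDiskDataset("/path/to/dataset_dir")

task = dataset.tasks[0]
print(task.metadata["name"])  # e.g. "edge_classification"
print(task.metadata["num_classes"])  # extra YAML fields land in the metadata
train_set = task.train_set  # ItemSet or ItemSetDict, ready for a dataloader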
"""Ondisk metadata of GraphBolt."""
from enum import Enum
from typing import List, Optional
from typing import Any, Dict, List, Optional
import pydantic
__all__ = [
"OnDiskFeatureDataFormat",
"OnDiskTVTSetData",
......@@ -15,6 +14,7 @@ __all__ = [
"OnDiskMetaData",
"OnDiskGraphTopologyType",
"OnDiskGraphTopology",
"OnDiskTaskData",
]
......@@ -71,6 +71,25 @@ class OnDiskGraphTopology(pydantic.BaseModel):
path: str
class OnDiskTaskData(pydantic.BaseModel, extra="allow"):
"""Task specification in YAML."""
train_set: Optional[List[OnDiskTVTSet]] = []
validation_set: Optional[List[OnDiskTVTSet]] = []
test_set: Optional[List[OnDiskTVTSet]] = []
extra_fields: Optional[Dict[str, Any]] = {}
@pydantic.model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra fields."""
for key in list(values.keys()):
if key not in cls.model_fields:
values["extra_fields"] = values.get("extra_fields", {})
values["extra_fields"][key] = values.pop(key)
return values
class OnDiskMetaData(pydantic.BaseModel):
"""Metadata specification in YAML.
......@@ -79,10 +98,6 @@ class OnDiskMetaData(pydantic.BaseModel):
"""
dataset_name: Optional[str] = None
num_classes: Optional[int] = None
num_labels: Optional[int] = None
graph_topology: Optional[OnDiskGraphTopology] = None
feature_data: Optional[List[OnDiskFeatureData]] = []
train_set: Optional[List[OnDiskTVTSet]] = []
validation_set: Optional[List[OnDiskTVTSet]] = []
test_set: Optional[List[OnDiskTVTSet]] = []
tasks: Optional[List[OnDiskTaskData]] = []
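To see `build_extra` in action, here is a small sketch that parses a YAML snippet with these models; the undeclared keys `name` and `num_classes` are routed into `extra_fields`:

import yaml

raw = yaml.safe_load(
    """
    dataset_name: demo
    tasks:
      - name: node_classification
        num_classes: 10
    """
)
meta = OnDiskMetaData(**raw)
task = meta.tasks[0]
# "name" and "num_classes" are not declared fields, so the validator
# moved them into extra_fields; the declared TVT fields keep defaults.
assert task.extra_fields == {"name": "node_classification", "num_classes": 10}
assert task.train_set == []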
@@ -5,18 +5,10 @@ from dgl import graphbolt as gb

def test_Dataset():
    dataset = gb.Dataset()
    with pytest.raises(NotImplementedError):
-       _ = dataset.train_set
-   with pytest.raises(NotImplementedError):
-       _ = dataset.validation_set
-   with pytest.raises(NotImplementedError):
-       _ = dataset.test_set
+       _ = dataset.tasks
    with pytest.raises(NotImplementedError):
        _ = dataset.graph
    with pytest.raises(NotImplementedError):
        _ = dataset.feature
    with pytest.raises(NotImplementedError):
        _ = dataset.dataset_name
-   with pytest.raises(NotImplementedError):
-       _ = dataset.num_classes
-   with pytest.raises(NotImplementedError):
-       _ = dataset.num_labels
@@ -22,6 +22,8 @@ def test_OnDiskDataset_TVTSet_exceptions():
        # Case 1: ``format`` is invalid.
        yaml_content = """
+           tasks:
+             - name: node_classification
                train_set:
                  - type: paper
                    data:

@@ -37,6 +39,8 @@ def test_OnDiskDataset_TVTSet_exceptions():
        # Case 2: ``type`` is not specified while multiple TVT sets are
        # specified.
        yaml_content = """
+           tasks:
+             - name: node_classification
                train_set:
                  - type: null
                    data:
@@ -85,6 +89,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
        # ``type`` is not specified or specified as ``null``.
        # ``in_memory`` could be ``true`` and ``false``.
        yaml_content = f"""
+           tasks:
+             - name: node_classification
+               num_classes: 10
                train_set:
                  - type: null
                    data:

@@ -119,8 +126,13 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
        dataset = gb.OnDiskDataset(test_dir)

+       # Verify tasks.
+       assert len(dataset.tasks) == 1
+       assert dataset.tasks[0].metadata["name"] == "node_classification"
+       assert dataset.tasks[0].metadata["num_classes"] == 10
+
        # Verify train set.
-       train_set = dataset.train_set
+       train_set = dataset.tasks[0].train_set
        assert len(train_set) == 1000
        assert isinstance(train_set, gb.ItemSet)
        for i, (id, label) in enumerate(train_set):
@@ -129,7 +141,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
        train_set = None

        # Verify validation set.
-       validation_set = dataset.validation_set
+       validation_set = dataset.tasks[0].validation_set
        assert len(validation_set) == 1000
        assert isinstance(validation_set, gb.ItemSet)
        for i, (id, label) in enumerate(validation_set):

@@ -138,7 +150,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
        validation_set = None

        # Verify test set.
-       test_set = dataset.test_set
+       test_set = dataset.tasks[0].test_set
        assert len(test_set) == 1000
        assert isinstance(test_set, gb.ItemSet)
        for i, (id, label) in enumerate(test_set):
@@ -149,6 +161,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():

        # Case 2: Some TVT sets are None.
        yaml_content = f"""
+           tasks:
+             - name: node_classification
                train_set:
                  - type: null
                    data:

@@ -160,9 +174,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
            f.write(yaml_content)
        dataset = gb.OnDiskDataset(test_dir)
-       assert dataset.train_set is not None
-       assert dataset.validation_set is None
-       assert dataset.test_set is None
+       assert dataset.tasks[0].train_set is not None
+       assert dataset.tasks[0].validation_set is None
+       assert dataset.tasks[0].test_set is None
        dataset = None
@@ -200,6 +214,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
        np.save(test_labels_path, test_labels)

        yaml_content = f"""
+           tasks:
+             - name: link_prediction
                train_set:
                  - type: null
                    data:

@@ -244,7 +260,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
        dataset = gb.OnDiskDataset(test_dir)

        # Verify train set.
-       train_set = dataset.train_set
+       train_set = dataset.tasks[0].train_set
        assert len(train_set) == 1000
        assert isinstance(train_set, gb.ItemSet)
        for i, (src, dst, label) in enumerate(train_set):

@@ -254,7 +270,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
        train_set = None

        # Verify validation set.
-       validation_set = dataset.validation_set
+       validation_set = dataset.tasks[0].validation_set
        assert len(validation_set) == 1000
        assert isinstance(validation_set, gb.ItemSet)
        for i, (src, dst, label) in enumerate(validation_set):

@@ -264,7 +280,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
        validation_set = None

        # Verify test set.
-       test_set = dataset.test_set
+       test_set = dataset.tasks[0].test_set
        assert len(test_set) == 1000
        assert isinstance(test_set, gb.ItemSet)
        for i, (src, dst, label) in enumerate(test_set):
@@ -313,6 +329,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
        np.save(test_neg_dst_path, test_neg_dst)

        yaml_content = f"""
+           tasks:
+             - name: link_prediction
                train_set:
                  - type: null
                    data:

@@ -357,7 +375,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
        dataset = gb.OnDiskDataset(test_dir)

        # Verify train set.
-       train_set = dataset.train_set
+       train_set = dataset.tasks[0].train_set
        assert len(train_set) == 1000
        assert isinstance(train_set, gb.ItemSet)
        for i, (src, dst, negs) in enumerate(train_set):

@@ -367,7 +385,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
        train_set = None

        # Verify validation set.
-       validation_set = dataset.validation_set
+       validation_set = dataset.tasks[0].validation_set
        assert len(validation_set) == 1000
        assert isinstance(validation_set, gb.ItemSet)
        for i, (src, dst, negs) in enumerate(validation_set):

@@ -377,7 +395,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
        validation_set = None

        # Verify test set.
-       test_set = dataset.test_set
+       test_set = dataset.tasks[0].test_set
        assert len(test_set) == 1000
        assert isinstance(test_set, gb.ItemSet)
        for i, (src, dst, negs) in enumerate(test_set):
@@ -410,6 +428,8 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
        np.save(test_path, test_data)

        yaml_content = f"""
+           tasks:
+             - name: node_classification
                train_set:
                  - type: paper
                    data:

@@ -448,7 +468,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
        dataset = gb.OnDiskDataset(test_dir)

        # Verify train set.
-       train_set = dataset.train_set
+       train_set = dataset.tasks[0].train_set
        assert len(train_set) == 2000
        assert isinstance(train_set, gb.ItemSetDict)
        for i, item in enumerate(train_set):

@@ -462,7 +482,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
        train_set = None

        # Verify validation set.
-       validation_set = dataset.validation_set
+       validation_set = dataset.tasks[0].validation_set
        assert len(validation_set) == 2000
        assert isinstance(validation_set, gb.ItemSetDict)
        for i, item in enumerate(validation_set):

@@ -476,7 +496,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
        validation_set = None

        # Verify test set.
-       test_set = dataset.test_set
+       test_set = dataset.tasks[0].test_set
        assert len(test_set) == 2000
        assert isinstance(test_set, gb.ItemSetDict)
        for i, item in enumerate(test_set):
@@ -513,6 +533,8 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
        np.save(test_path, test_data)

        yaml_content = f"""
+           tasks:
+             - name: edge_classification
                train_set:
                  - type: paper
                    data:

@@ -551,7 +573,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
        dataset = gb.OnDiskDataset(test_dir)

        # Verify train set.
-       train_set = dataset.train_set
+       train_set = dataset.tasks[0].train_set
        assert len(train_set) == 2000
        assert isinstance(train_set, gb.ItemSetDict)
        for i, item in enumerate(train_set):

@@ -566,7 +588,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
        train_set = None

        # Verify validation set.
-       validation_set = dataset.validation_set
+       validation_set = dataset.tasks[0].validation_set
        assert len(validation_set) == 2000
        assert isinstance(validation_set, gb.ItemSetDict)
        for i, item in enumerate(validation_set):

@@ -581,7 +603,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
        validation_set = None

        # Verify test set.
-       test_set = dataset.test_set
+       test_set = dataset.tasks[0].test_set
        assert len(test_set) == 2000
        assert isinstance(test_set, gb.ItemSetDict)
        for i, item in enumerate(test_set):
@@ -860,12 +882,8 @@ def test_OnDiskDataset_Metadata():
    with tempfile.TemporaryDirectory() as test_dir:
        # All metadata fields are specified.
        dataset_name = "graphbolt_test"
-       num_classes = 10
-       num_labels = 9
        yaml_content = f"""
            dataset_name: {dataset_name}
-           num_classes: {num_classes}
-           num_labels: {num_labels}
        """
        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")

@@ -874,8 +892,6 @@ def test_OnDiskDataset_Metadata():
        dataset = gb.OnDiskDataset(test_dir)
        assert dataset.dataset_name == dataset_name
-       assert dataset.num_classes == num_classes
-       assert dataset.num_labels == num_labels

        # Only dataset_name is specified.
        yaml_content = f"""

@@ -887,8 +903,6 @@ def test_OnDiskDataset_Metadata():
        dataset = gb.OnDiskDataset(test_dir)
        assert dataset.dataset_name == dataset_name
-       assert dataset.num_classes is None
-       assert dataset.num_labels is None
def test_OnDiskDataset_preprocess_homogeneous():
@@ -899,7 +913,6 @@ def test_OnDiskDataset_preprocess_homogeneous():
        num_nodes = 4000
        num_edges = 20000
        num_classes = 10
-       num_labels = 9

        # Generate random edges.
        nodes = np.repeat(np.arange(num_nodes), 5)

@@ -945,8 +958,6 @@ def test_OnDiskDataset_preprocess_homogeneous():
        yaml_content = f"""
            dataset_name: {dataset_name}
-           num_classes: {num_classes}
-           num_labels: {num_labels}
            graph: # graph structure and required attributes.
              nodes:
                - num: {num_nodes}

@@ -967,6 +978,9 @@ def test_OnDiskDataset_preprocess_homogeneous():
                  format: numpy
                  in_memory: false
                  path: data/node-feat.npy
+           tasks:
+             - name: node_classification
+               num_classes: {num_classes}
                train_set:
                  - type_name: null
                    data:

@@ -992,8 +1006,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
            processed_dataset = yaml.load(f, Loader=yaml.Loader)

        assert processed_dataset["dataset_name"] == dataset_name
-       assert processed_dataset["num_classes"] == num_classes
-       assert processed_dataset["num_labels"] == num_labels
+       assert processed_dataset["tasks"][0]["num_classes"] == num_classes
        assert "graph" not in processed_dataset
        assert "graph_topology" in processed_dataset
@@ -1018,12 +1031,9 @@ def test_OnDiskDataset_preprocess_path():
        # All metadata fields are specified.
        dataset_name = "graphbolt_test"
-       num_classes = 10
-       num_labels = 9
        yaml_content = f"""
            dataset_name: {dataset_name}
-           num_classes: {num_classes}
-           num_labels: {num_labels}
        """
        yaml_file = os.path.join(test_dir, "metadata.yaml")
        with open(yaml_file, "w") as f:
    ...