Unverified commit 6b047e4d authored by Rhett Ying, committed by GitHub

[GraphBolt] wrap tvt with task structure (#6112)

parent a2234d60
"""GraphBolt Dataset.""" """GraphBolt Dataset."""
from typing import Dict, List
from .feature_store import FeatureStore from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
__all__ = ["Dataset"] __all__ = [
"Task",
"Dataset",
]
-class Dataset:
-    """An abstract dataset.
+class Task:
+    """An abstract task.
 
-    Dataset provides abstraction for accessing the data required for training.
-    The data abstraction could be a native CPU memory block, a shared memory
-    block, a file handle of an opened file on disk, a service that provides
-    the API to access the data, etc. There are 3 primary components in the
-    dataset: *Train-Validation-Test Set*, *Feature Storage*, *Graph Topology*.
+    A task consists of meta information and a *Train-Validation-Test Set*.
+
+    *Meta information*:
+    The meta information of a task includes any kind of data defined by the
+    user in YAML when the task is instantiated.
     *Train-Validation-Test Set*:
     The training-validation-testing (TVT) set which is used to train the neural
     networks. We calculate the embeddings based on their respective features
     and the graph structure, and then utilize the embeddings to optimize the
     neural network parameters.
-
-    *Feature Storage*:
-    A key-value store which stores node/edge/graph features.
-
-    *Graph Topology*:
-    Graph topology is used by the subgraph sampling algorithm to
-    generate a subgraph.
     """
 
+    @property
+    def metadata(self) -> Dict:
+        """Return the task metadata."""
+        raise NotImplementedError
+
     @property
     def train_set(self) -> ItemSet or ItemSetDict:
         """Return the training set."""
@@ -44,6 +46,33 @@ class Dataset:
         """Return the test set."""
         raise NotImplementedError
+
+class Dataset:
+    """An abstract dataset.
+
+    Dataset provides abstraction for accessing the data required for training.
+    The data abstraction could be a native CPU memory block, a shared memory
+    block, a file handle of an opened file on disk, a service that provides
+    the API to access the data, etc. There are 3 primary components in the
+    dataset: *Task*, *Feature Storage*, *Graph Topology*.
+
+    *Task*:
+    A task consists of meta information and a *Train-Validation-Test Set*.
+    A dataset could have multiple tasks.
+
+    *Feature Storage*:
+    A key-value store which stores node/edge/graph features.
+
+    *Graph Topology*:
+    Graph topology is used by the subgraph sampling algorithm to
+    generate a subgraph.
+    """
+
+    @property
+    def tasks(self) -> List[Task]:
+        """Return the tasks."""
+        raise NotImplementedError
+
     @property
     def graph(self) -> object:
         """Return the graph."""
@@ -58,13 +87,3 @@ class Dataset:
     def dataset_name(self) -> str:
         """Return the dataset name."""
         raise NotImplementedError
-
-    @property
-    def num_classes(self) -> int:
-        """Return the number of classes."""
-        raise NotImplementedError
-
-    @property
-    def num_labels(self) -> int:
-        """Return the number of labels."""
-        raise NotImplementedError
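
Under the new structure, training code first selects a task and then reads the
TVT sets and task-level metadata (such as num_classes) from it, instead of the
removed dataset-level num_classes/num_labels. A minimal consumption sketch, not
part of this diff; summarize is a hypothetical helper:

    from dgl import graphbolt as gb

    def summarize(dataset: gb.Dataset) -> None:
        """Print each task's name and training-set size."""
        for task in dataset.tasks:
            meta = task.metadata  # e.g. {"name": ..., "num_classes": ...}
            print(meta.get("name"), len(task.train_set))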
@@ -4,7 +4,7 @@ import os
 import shutil
 from copy import deepcopy
-from typing import List
+from typing import Dict, List
 
 import pandas as pd
 import torch
@@ -12,7 +12,7 @@ import yaml
 import dgl
 
-from ..dataset import Dataset
+from ..dataset import Dataset, Task
 from ..itemset import ItemSet, ItemSetDict
 from ..utils import read_data, save_data
@@ -22,7 +22,12 @@ from .csc_sampling_graph import (
     load_csc_sampling_graph,
     save_csc_sampling_graph,
 )
-from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
+from .ondisk_metadata import (
+    OnDiskGraphTopology,
+    OnDiskMetaData,
+    OnDiskTaskData,
+    OnDiskTVTSet,
+)
 from .torch_based_feature_store import TorchBasedFeatureStore
 
 __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
@@ -178,12 +183,16 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
                 feature["in_memory"],
             )
 
-    # 7. Save the train/val/test split according to the output_config.
-    for set_name in ["train_set", "validation_set", "test_set"]:
-        if set_name not in input_config:
-            continue
-        for input_set_per_type, output_set_per_type in zip(
-            input_config[set_name], output_config[set_name]
-        ):
-            for input_data, output_data in zip(
-                input_set_per_type["data"], output_set_per_type["data"]
+    # 7. Save tasks and train/val/test split according to the output_config.
+    if input_config.get("tasks", None):
+        for input_task, output_task in zip(
+            input_config["tasks"], output_config["tasks"]
+        ):
+            for set_name in ["train_set", "validation_set", "test_set"]:
+                if set_name not in input_task:
+                    continue
+                for input_set_per_type, output_set_per_type in zip(
+                    input_task[set_name], output_task[set_name]
+                ):
+                    for input_data, output_data in zip(
+                        input_set_per_type["data"], output_set_per_type["data"]
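
The save step now nests the former per-set loop inside a per-task loop. A
standalone sketch of the traversal over a parsed YAML config, with shapes taken
from the examples in this diff and illustrative values:

    input_config = {
        "tasks": [
            {
                "name": "node_classification",
                "train_set": [
                    {"type": None, "data": [{"format": "numpy", "path": "a.npy"}]}
                ],
            }
        ]
    }
    if input_config.get("tasks", None):
        for task in input_config["tasks"]:
            for set_name in ["train_set", "validation_set", "test_set"]:
                if set_name not in task:
                    continue
                for set_per_type in task[set_name]:
                    for data in set_per_type["data"]:
                        # the real code re-saves each data file at this point
                        print(task["name"], set_name, data["path"])
    # -> node_classification train_set a.npy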
@@ -211,6 +220,59 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
     return output_config_path
 
+class OnDiskTask:
+    """An on-disk task.
+
+    An on-disk task is for ``OnDiskDataset``. It contains the metadata and the
+    train/val/test sets.
+    """
+
+    def __init__(
+        self,
+        metadata: Dict,
+        train_set: ItemSet or ItemSetDict,
+        validation_set: ItemSet or ItemSetDict,
+        test_set: ItemSet or ItemSetDict,
+    ):
+        """Initialize a task.
+
+        Parameters
+        ----------
+        metadata : Dict
+            Metadata.
+        train_set : ItemSet or ItemSetDict
+            Training set.
+        validation_set : ItemSet or ItemSetDict
+            Validation set.
+        test_set : ItemSet or ItemSetDict
+            Test set.
+        """
+        self._metadata = metadata
+        self._train_set = train_set
+        self._validation_set = validation_set
+        self._test_set = test_set
+
+    @property
+    def metadata(self) -> Dict:
+        """Return the task metadata."""
+        return self._metadata
+
+    @property
+    def train_set(self) -> ItemSet or ItemSetDict:
+        """Return the training set."""
+        return self._train_set
+
+    @property
+    def validation_set(self) -> ItemSet or ItemSetDict:
+        """Return the validation set."""
+        return self._validation_set
+
+    @property
+    def test_set(self) -> ItemSet or ItemSetDict:
+        """Return the test set."""
+        return self._test_set
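
A quick construction sketch for the class above, normally done by
``OnDiskDataset._init_tasks``; the tiny ItemSets are illustrative and assume
``gb.ItemSet`` accepts a tensor of IDs, as the tests below do:

    import torch
    from dgl import graphbolt as gb

    task = OnDiskTask(
        metadata={"name": "node_classification", "num_classes": 10},
        train_set=gb.ItemSet(torch.arange(8)),
        validation_set=gb.ItemSet(torch.arange(8, 10)),
        test_set=gb.ItemSet(torch.arange(10, 12)),
    )
    assert task.metadata["num_classes"] == 10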
 class OnDiskDataset(Dataset):
     """An on-disk dataset.
@@ -225,8 +287,6 @@ class OnDiskDataset(Dataset):
     .. code-block:: yaml
 
         dataset_name: graphbolt_test
-        num_classes: 10
-        num_labels: 10
         graph_topology:
           type: CSCSamplingGraph
           path: graph_topology/csc_sampling_graph.tar
@@ -243,6 +303,9 @@ class OnDiskDataset(Dataset):
             format: numpy
             in_memory: false
            path: edge_data/author-writes-paper-feat.npy
+        tasks:
+          - name: "edge_classification"
+            num_classes: 10
             train_set:
               - type: paper # could be null for homogeneous graph.
                 data: # multiple data sources could be specified.
@@ -279,28 +342,14 @@ class OnDiskDataset(Dataset):
             yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
         self._meta = OnDiskMetaData(**yaml_data)
         self._dataset_name = self._meta.dataset_name
-        self._num_classes = self._meta.num_classes
-        self._num_labels = self._meta.num_labels
         self._graph = self._load_graph(self._meta.graph_topology)
         self._feature = TorchBasedFeatureStore(self._meta.feature_data)
-        self._train_set = self._init_tvt_set(self._meta.train_set)
-        self._validation_set = self._init_tvt_set(self._meta.validation_set)
-        self._test_set = self._init_tvt_set(self._meta.test_set)
+        self._tasks = self._init_tasks(self._meta.tasks)
 
     @property
-    def train_set(self) -> ItemSet or ItemSetDict:
-        """Return the training set."""
-        return self._train_set
-
-    @property
-    def validation_set(self) -> ItemSet or ItemSetDict:
-        """Return the validation set."""
-        return self._validation_set
-
-    @property
-    def test_set(self) -> ItemSet or ItemSetDict:
-        """Return the test set."""
-        return self._test_set
+    def tasks(self) -> List[Task]:
+        """Return the tasks."""
+        return self._tasks
 
     @property
     def graph(self) -> object:
@@ -317,15 +366,21 @@ class OnDiskDataset(Dataset):
         """Return the dataset name."""
         return self._dataset_name
 
-    @property
-    def num_classes(self) -> int:
-        """Return the number of classes."""
-        return self._num_classes
-
-    @property
-    def num_labels(self) -> int:
-        """Return the number of labels."""
-        return self._num_labels
+    def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
+        """Initialize the tasks."""
+        ret = []
+        if tasks is None:
+            return ret
+        for task in tasks:
+            ret.append(
+                OnDiskTask(
+                    task.extra_fields,
+                    self._init_tvt_set(task.train_set),
+                    self._init_tvt_set(task.validation_set),
+                    self._init_tvt_set(task.test_set),
+                )
+            )
+        return ret
 
     def _load_graph(
         self, graph_topology: OnDiskGraphTopology
......
"""Ondisk metadata of GraphBolt.""" """Ondisk metadata of GraphBolt."""
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Any, Dict, List, Optional
import pydantic import pydantic
__all__ = [ __all__ = [
"OnDiskFeatureDataFormat", "OnDiskFeatureDataFormat",
"OnDiskTVTSetData", "OnDiskTVTSetData",
...@@ -15,6 +14,7 @@ __all__ = [ ...@@ -15,6 +14,7 @@ __all__ = [
"OnDiskMetaData", "OnDiskMetaData",
"OnDiskGraphTopologyType", "OnDiskGraphTopologyType",
"OnDiskGraphTopology", "OnDiskGraphTopology",
"OnDiskTaskData",
] ]
...@@ -71,6 +71,25 @@ class OnDiskGraphTopology(pydantic.BaseModel): ...@@ -71,6 +71,25 @@ class OnDiskGraphTopology(pydantic.BaseModel):
path: str path: str
 
+class OnDiskTaskData(pydantic.BaseModel, extra="allow"):
+    """Task specification in YAML."""
+
+    train_set: Optional[List[OnDiskTVTSet]] = []
+    validation_set: Optional[List[OnDiskTVTSet]] = []
+    test_set: Optional[List[OnDiskTVTSet]] = []
+    extra_fields: Optional[Dict[str, Any]] = {}
+
+    @pydantic.model_validator(mode="before")
+    @classmethod
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra fields."""
+        for key in list(values.keys()):
+            if key not in cls.model_fields:
+                values["extra_fields"] = values.get("extra_fields", {})
+                values["extra_fields"][key] = values.pop(key)
+        return values
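
The ``mode="before"`` validator runs on the raw YAML mapping, so any key that is
not a declared field (``name``, ``num_classes``, and so on) is moved into
``extra_fields`` before field validation. A small sketch of the effect, using
the model above with illustrative values:

    task_data = OnDiskTaskData(
        name="node_classification",
        num_classes=10,
        train_set=[],
    )
    # The undeclared keys were collected; the declared field stayed put.
    assert task_data.extra_fields == {
        "name": "node_classification",
        "num_classes": 10,
    }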
 
 class OnDiskMetaData(pydantic.BaseModel):
     """Metadata specification in YAML.
@@ -79,10 +98,6 @@ class OnDiskMetaData(pydantic.BaseModel):
     """
 
     dataset_name: Optional[str] = None
-    num_classes: Optional[int] = None
-    num_labels: Optional[int] = None
     graph_topology: Optional[OnDiskGraphTopology] = None
     feature_data: Optional[List[OnDiskFeatureData]] = []
-    train_set: Optional[List[OnDiskTVTSet]] = []
-    validation_set: Optional[List[OnDiskTVTSet]] = []
-    test_set: Optional[List[OnDiskTVTSet]] = []
+    tasks: Optional[List[OnDiskTaskData]] = []
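
End to end, pydantic coerces each entry under ``tasks`` into an
``OnDiskTaskData``, running the validator above. A sketch of parsing a minimal
metadata YAML, with illustrative values:

    import yaml

    yaml_content = """
    dataset_name: graphbolt_test
    tasks:
      - name: node_classification
        num_classes: 10
    """
    meta = OnDiskMetaData(**yaml.safe_load(yaml_content))
    assert meta.dataset_name == "graphbolt_test"
    assert meta.tasks[0].extra_fields["num_classes"] == 10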
@@ -5,18 +5,10 @@ from dgl import graphbolt as gb
 def test_Dataset():
     dataset = gb.Dataset()
     with pytest.raises(NotImplementedError):
-        _ = dataset.train_set
-    with pytest.raises(NotImplementedError):
-        _ = dataset.validation_set
-    with pytest.raises(NotImplementedError):
-        _ = dataset.test_set
+        _ = dataset.tasks
     with pytest.raises(NotImplementedError):
         _ = dataset.graph
     with pytest.raises(NotImplementedError):
         _ = dataset.feature
     with pytest.raises(NotImplementedError):
         _ = dataset.dataset_name
-    with pytest.raises(NotImplementedError):
-        _ = dataset.num_classes
-    with pytest.raises(NotImplementedError):
-        _ = dataset.num_labels
@@ -22,6 +22,8 @@ def test_OnDiskDataset_TVTSet_exceptions():
         # Case 1: ``format`` is invalid.
         yaml_content = """
+            tasks:
+              - name: node_classification
                 train_set:
                   - type: paper
                     data:
@@ -37,6 +39,8 @@ def test_OnDiskDataset_TVTSet_exceptions():
         # Case 2: ``type`` is not specified while multiple TVT sets are
         # specified.
         yaml_content = """
+            tasks:
+              - name: node_classification
                 train_set:
                   - type: null
                     data:
@@ -85,6 +89,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
         # ``type`` is not specified or specified as ``null``.
         # ``in_memory`` could be ``true`` and ``false``.
         yaml_content = f"""
+            tasks:
+              - name: node_classification
+                num_classes: 10
                 train_set:
                   - type: null
                     data:
@@ -119,8 +126,13 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
 
         dataset = gb.OnDiskDataset(test_dir)
 
+        # Verify tasks.
+        assert len(dataset.tasks) == 1
+        assert dataset.tasks[0].metadata["name"] == "node_classification"
+        assert dataset.tasks[0].metadata["num_classes"] == 10
+
         # Verify train set.
-        train_set = dataset.train_set
+        train_set = dataset.tasks[0].train_set
         assert len(train_set) == 1000
         assert isinstance(train_set, gb.ItemSet)
         for i, (id, label) in enumerate(train_set):
@@ -129,7 +141,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
         train_set = None
 
         # Verify validation set.
-        validation_set = dataset.validation_set
+        validation_set = dataset.tasks[0].validation_set
         assert len(validation_set) == 1000
         assert isinstance(validation_set, gb.ItemSet)
         for i, (id, label) in enumerate(validation_set):
@@ -138,7 +150,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
         validation_set = None
 
         # Verify test set.
-        test_set = dataset.test_set
+        test_set = dataset.tasks[0].test_set
         assert len(test_set) == 1000
         assert isinstance(test_set, gb.ItemSet)
         for i, (id, label) in enumerate(test_set):
@@ -149,6 +161,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
 
         # Case 2: Some TVT sets are None.
         yaml_content = f"""
+            tasks:
+              - name: node_classification
                 train_set:
                   - type: null
                     data:
@@ -160,9 +174,9 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
             f.write(yaml_content)
 
         dataset = gb.OnDiskDataset(test_dir)
-        assert dataset.train_set is not None
-        assert dataset.validation_set is None
-        assert dataset.test_set is None
+        assert dataset.tasks[0].train_set is not None
+        assert dataset.tasks[0].validation_set is None
+        assert dataset.tasks[0].test_set is None
         dataset = None
@@ -200,6 +214,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         np.save(test_labels_path, test_labels)
 
         yaml_content = f"""
+            tasks:
+              - name: link_prediction
                 train_set:
                   - type: null
                     data:
@@ -244,7 +260,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         dataset = gb.OnDiskDataset(test_dir)
 
         # Verify train set.
-        train_set = dataset.train_set
+        train_set = dataset.tasks[0].train_set
         assert len(train_set) == 1000
         assert isinstance(train_set, gb.ItemSet)
         for i, (src, dst, label) in enumerate(train_set):
@@ -254,7 +270,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         train_set = None
 
         # Verify validation set.
-        validation_set = dataset.validation_set
+        validation_set = dataset.tasks[0].validation_set
         assert len(validation_set) == 1000
         assert isinstance(validation_set, gb.ItemSet)
         for i, (src, dst, label) in enumerate(validation_set):
@@ -264,7 +280,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
         validation_set = None
 
         # Verify test set.
-        test_set = dataset.test_set
+        test_set = dataset.tasks[0].test_set
         assert len(test_set) == 1000
         assert isinstance(test_set, gb.ItemSet)
         for i, (src, dst, label) in enumerate(test_set):
@@ -313,6 +329,8 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         np.save(test_neg_dst_path, test_neg_dst)
 
         yaml_content = f"""
+            tasks:
+              - name: link_prediction
                 train_set:
                   - type: null
                     data:
@@ -357,7 +375,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         dataset = gb.OnDiskDataset(test_dir)
 
         # Verify train set.
-        train_set = dataset.train_set
+        train_set = dataset.tasks[0].train_set
         assert len(train_set) == 1000
         assert isinstance(train_set, gb.ItemSet)
         for i, (src, dst, negs) in enumerate(train_set):
@@ -367,7 +385,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         train_set = None
 
        # Verify validation set.
-        validation_set = dataset.validation_set
+        validation_set = dataset.tasks[0].validation_set
         assert len(validation_set) == 1000
         assert isinstance(validation_set, gb.ItemSet)
         for i, (src, dst, negs) in enumerate(validation_set):
@@ -377,7 +395,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
         validation_set = None
 
         # Verify test set.
-        test_set = dataset.test_set
+        test_set = dataset.tasks[0].test_set
         assert len(test_set) == 1000
         assert isinstance(test_set, gb.ItemSet)
         for i, (src, dst, negs) in enumerate(test_set):
@@ -410,6 +428,8 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         np.save(test_path, test_data)
 
         yaml_content = f"""
+            tasks:
+              - name: node_classification
                 train_set:
                   - type: paper
                     data:
@@ -448,7 +468,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         dataset = gb.OnDiskDataset(test_dir)
 
         # Verify train set.
-        train_set = dataset.train_set
+        train_set = dataset.tasks[0].train_set
         assert len(train_set) == 2000
         assert isinstance(train_set, gb.ItemSetDict)
         for i, item in enumerate(train_set):
@@ -462,7 +482,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         train_set = None
 
         # Verify validation set.
-        validation_set = dataset.validation_set
+        validation_set = dataset.tasks[0].validation_set
         assert len(validation_set) == 2000
         assert isinstance(validation_set, gb.ItemSetDict)
         for i, item in enumerate(validation_set):
@@ -476,7 +496,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         validation_set = None
 
         # Verify test set.
-        test_set = dataset.test_set
+        test_set = dataset.tasks[0].test_set
         assert len(test_set) == 2000
         assert isinstance(test_set, gb.ItemSetDict)
         for i, item in enumerate(test_set):
@@ -513,6 +533,8 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         np.save(test_path, test_data)
 
         yaml_content = f"""
+            tasks:
+              - name: edge_classification
                 train_set:
                   - type: paper
                     data:
@@ -551,7 +573,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         dataset = gb.OnDiskDataset(test_dir)
 
         # Verify train set.
-        train_set = dataset.train_set
+        train_set = dataset.tasks[0].train_set
         assert len(train_set) == 2000
         assert isinstance(train_set, gb.ItemSetDict)
         for i, item in enumerate(train_set):
@@ -566,7 +588,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         train_set = None
 
         # Verify validation set.
-        validation_set = dataset.validation_set
+        validation_set = dataset.tasks[0].validation_set
         assert len(validation_set) == 2000
         assert isinstance(validation_set, gb.ItemSetDict)
         for i, item in enumerate(validation_set):
@@ -581,7 +603,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
         validation_set = None
 
         # Verify test set.
-        test_set = dataset.test_set
+        test_set = dataset.tasks[0].test_set
         assert len(test_set) == 2000
         assert isinstance(test_set, gb.ItemSetDict)
         for i, item in enumerate(test_set):
@@ -860,12 +882,8 @@ def test_OnDiskDataset_Metadata():
     with tempfile.TemporaryDirectory() as test_dir:
         # All metadata fields are specified.
        dataset_name = "graphbolt_test"
-        num_classes = 10
-        num_labels = 9
         yaml_content = f"""
             dataset_name: {dataset_name}
-            num_classes: {num_classes}
-            num_labels: {num_labels}
         """
         os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
         yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
@@ -874,8 +892,6 @@ def test_OnDiskDataset_Metadata():
 
         dataset = gb.OnDiskDataset(test_dir)
         assert dataset.dataset_name == dataset_name
-        assert dataset.num_classes == num_classes
-        assert dataset.num_labels == num_labels
 
         # Only dataset_name is specified.
         yaml_content = f"""
@@ -887,8 +903,6 @@ def test_OnDiskDataset_Metadata():
 
         dataset = gb.OnDiskDataset(test_dir)
         assert dataset.dataset_name == dataset_name
-        assert dataset.num_classes is None
-        assert dataset.num_labels is None
 def test_OnDiskDataset_preprocess_homogeneous():
@@ -899,7 +913,6 @@ def test_OnDiskDataset_preprocess_homogeneous():
         num_nodes = 4000
         num_edges = 20000
         num_classes = 10
-        num_labels = 9
 
         # Generate random edges.
         nodes = np.repeat(np.arange(num_nodes), 5)
@@ -945,8 +958,6 @@ def test_OnDiskDataset_preprocess_homogeneous():
         yaml_content = f"""
             dataset_name: {dataset_name}
-            num_classes: {num_classes}
-            num_labels: {num_labels}
             graph: # graph structure and required attributes.
               nodes:
                 - num: {num_nodes}
@@ -967,6 +978,9 @@ def test_OnDiskDataset_preprocess_homogeneous():
                 format: numpy
                 in_memory: false
                 path: data/node-feat.npy
+            tasks:
+              - name: node_classification
+                num_classes: {num_classes}
                 train_set:
                   - type_name: null
                     data:
@@ -992,8 +1006,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
             processed_dataset = yaml.load(f, Loader=yaml.Loader)
 
         assert processed_dataset["dataset_name"] == dataset_name
-        assert processed_dataset["num_classes"] == num_classes
-        assert processed_dataset["num_labels"] == num_labels
+        assert processed_dataset["tasks"][0]["num_classes"] == num_classes
         assert "graph" not in processed_dataset
         assert "graph_topology" in processed_dataset
@@ -1018,12 +1031,9 @@ def test_OnDiskDataset_preprocess_path():
         # All metadata fields are specified.
         dataset_name = "graphbolt_test"
         num_classes = 10
-        num_labels = 9
         yaml_content = f"""
             dataset_name: {dataset_name}
-            num_classes: {num_classes}
-            num_labels: {num_labels}
         """
         yaml_file = os.path.join(test_dir, "metadata.yaml")
         with open(yaml_file, "w") as f:
......