Unverified commit 2668d62f, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] add abstract Dataset (#5926)

parent 1cbe0b27
......@@ -9,6 +9,7 @@ from .graph_storage import *
from .itemset import *
from .minibatch_sampler import *
from .feature_store import *
from .dataset import *
from .subgraph_sampler import *
......
"""GraphBolt Dataset."""
from typing import Union

from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
__all__ = ["Dataset"]
class Dataset:
    """An abstract dataset.

    Dataset provides abstraction for accessing the data required for training.
    The data abstraction could be a native CPU memory block, a shared memory
    block, a file handle of an opened file on disk, a service that provides
    the API to access the data, etc. There are 3 primary components in the
    dataset: *Train-Validation-Test Set*, *Feature Storage*, *Graph Topology*.

    *Train-Validation-Test Set*:
    The training-validation-testing (TVT) set which is used to train the neural
    networks. We calculate the embeddings based on their respective features
    and the graph structure, and then utilize the embeddings to optimize the
    neural network parameters.

    *Feature Storage*:
    A key-value store which stores node/edge/graph features.

    *Graph Topology*:
    Graph topology is used by the subgraph sampling algorithm to
    generate a subgraph.

    Notes
    -----
    This base class implements none of the accessors below: every one of
    them raises ``NotImplementedError`` when called.
    """

    # NOTE: the return annotations use ``Union[...]`` rather than
    # ``ItemSet or ItemSetDict``. The latter is an ordinary boolean
    # expression that evaluates to ``ItemSet`` alone, so ``ItemSetDict``
    # would be silently lost from the annotation.

    def train_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the training set.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def validation_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the validation set.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def test_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the test set.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def graph(self) -> object:
        """Return the graph.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def feature(self) -> FeatureStore:
        """Return the feature.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
import pytest
from dgl import graphbolt as gb
def test_Dataset():
    """Check that every accessor of the base ``Dataset`` is unimplemented."""
    dataset = gb.Dataset()
    accessor_names = (
        "train_set",
        "validation_set",
        "test_set",
        "graph",
        "feature",
    )
    for accessor_name in accessor_names:
        # Look the accessor up and call it inside the context manager so
        # the NotImplementedError is caught for each one in turn.
        with pytest.raises(NotImplementedError):
            _ = getattr(dataset, accessor_name)()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment