"vscode:/vscode.git/clone" did not exist on "0905d2e270853d1de11ab93adefe067606c3287c"
Unverified Commit 39890c0c authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add pydantic-based metadata for TVT (#5942)

parent 55af15d4
"""GraphBolt Dataset."""
from typing import List, Optional
import pydantic
import pydantic_yaml
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
__all__ = ["Dataset"]
__all__ = ["Dataset", "OnDiskDataset"]
class Dataset:
......@@ -48,3 +53,87 @@ class Dataset:
def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError
class OnDiskDataFormatEnum(pydantic_yaml.YamlStrEnum):
"""Enum of data format."""
TORCH = "torch"
NUMPY = "numpy"
class OnDiskTVTSet(pydantic.BaseModel):
"""Train-Validation-Test set."""
type_name: str
format: OnDiskDataFormatEnum
path: str
class OnDiskMetaData(pydantic_yaml.YamlModel):
"""Metadata specification in YAML.
As multiple node/edge types and multiple splits are supported, each TVT set
is a list of list of ``OnDiskTVTSet``.
"""
train_set: Optional[List[List[OnDiskTVTSet]]]
validation_set: Optional[List[List[OnDiskTVTSet]]]
test_set: Optional[List[List[OnDiskTVTSet]]]
class OnDiskDataset(Dataset):
"""An on-disk dataset.
An on-disk dataset is a dataset which reads graph topology, feature data
and TVT set from disk. Due to limited resources, the data which are too
large to fit into RAM will remain on disk while others reside in RAM once
``OnDiskDataset`` is initialized. This behavior could be controled by user
via ``in_memory`` field in YAML file.
A full example of YAML file is as follows:
.. code-block:: yaml
train_set:
- - type_name: paper
format: numpy
path: set/paper-train.npy
validation_set:
- - type_name: paper
format: numpy
path: set/paper-validation.npy
test_set:
- - type_name: paper
format: numpy
path: set/paper-test.npy
Parameters
----------
path: str
The YAML file path.
"""
def __init__(self, path: str) -> None:
with open(path, "r") as f:
self._meta = OnDiskMetaData.parse_raw(f.read(), proto="yaml")
def train_set(self) -> ItemSet or ItemSetDict:
"""Return the training set."""
raise NotImplementedError
def validation_set(self) -> ItemSet or ItemSetDict:
"""Return the validation set."""
raise NotImplementedError
def test_set(self) -> ItemSet or ItemSetDict:
"""Return the test set."""
raise NotImplementedError
def graph(self) -> object:
"""Return the graph."""
raise NotImplementedError
def feature(self) -> FeatureStore:
"""Return the feature."""
raise NotImplementedError
......@@ -19,6 +19,7 @@ dependencies:
- psutil
- pyarrow
- pydantic
- pydantic-yaml
- pytest
- pyyaml
- rdflib
......@@ -40,5 +41,6 @@ dependencies:
- pillow
- seaborn
- jupyter_http_over_ws
- ufmt
variables:
DGL_HOME: __DGL_HOME__
import os
import tempfile
import pydantic
import pytest
from dgl import graphbolt as gb
......@@ -14,3 +18,36 @@ def test_Dataset():
_ = dataset.graph()
with pytest.raises(NotImplementedError):
_ = dataset.feature()
def test_OnDiskDataset_TVTSet():
"""Test OnDiskDataset with TVTSet."""
with tempfile.TemporaryDirectory() as test_dir:
yaml_content = """
train_set:
- - type_name: paper
format: torch
path: set/paper-train.pt
- type_name: 'paper:cites:paper'
format: numpy
path: set/cites-train.pt
"""
yaml_file = os.path.join(test_dir, "test.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
_ = gb.OnDiskDataset(yaml_file)
# Invalid format.
yaml_content = """
train_set:
- - type_name: paper
format: torch_invalid
path: set/paper-train.pt
- type_name: 'paper:cites:paper'
format: numpy_invalid
path: set/cites-train.pt
"""
with open(yaml_file, "w") as f:
f.write(yaml_content)
with pytest.raises(pydantic.ValidationError):
_ = gb.OnDiskDataset(yaml_file)
......@@ -34,6 +34,9 @@ export PYTHONUNBUFFERED=1
export OMP_NUM_THREADS=1
export DMLC_LOG_DEBUG=1
# Install required dependencies
python3 -m pip install pydantic-yaml
python3 -m pytest -v --capture=tee-sys --junitxml=pytest_distributed.xml --durations=100 tests/distributed/*.py || fail "distributed"
PYTHONPATH=tools:tools/distpartitioning:$PYTHONPATH python3 -m pytest -v --capture=tee-sys --junitxml=pytest_tools.xml --durations=100 tests/tools/*.py || fail "tools"
......@@ -14,7 +14,7 @@ SET DGLBACKEND=!BACKEND!
SET DGL_LIBRARY_PATH=!CD!\build
SET DGL_DOWNLOAD_DIR=!CD!
python -m pip install pytest psutil pandas pyyaml pydantic rdflib torchmetrics || EXIT /B 1
python -m pip install pytest psutil pandas pyyaml pydantic pydantic-yaml rdflib torchmetrics || EXIT /B 1
python -m pytest -v --junitxml=pytest_backend.xml --durations=100 tests\python\!DGLBACKEND! || EXIT /B 1
python -m pytest -v --junitxml=pytest_common.xml --durations=100 tests\python\common || EXIT /B 1
ENDLOCAL
......
......@@ -33,6 +33,9 @@ fi
conda activate ${DGLBACKEND}-ci
# Install required dependencies
python3 -m pip install pydantic-yaml
if [ $DGLBACKEND == "mxnet" ]
then
python3 -m pytest -v --junitxml=pytest_compute.xml --durations=100 --ignore=tests/python/common/test_ffi.py tests/python/common || fail "common"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment