Unverified Commit fe96f11f authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable to download built-in dataset from s3 (#6329)

parent c036222b
......@@ -11,6 +11,7 @@ import yaml
import dgl
from ...data.utils import download, extract_archive
from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
......@@ -29,7 +30,7 @@ from .ondisk_metadata import (
)
from .torch_based_feature_store import TorchBasedFeatureStore
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
def _copy_or_convert_data(
......@@ -473,3 +474,31 @@ class OnDiskDataset(Dataset):
)
ret = ItemSetDict(data)
return ret
class BuiltinDataset(OnDiskDataset):
"""GraphBolt builtin on-disk dataset.
This class is used to help download datasets from DGL S3 storage and load
them as ``OnDiskDataset``.
Parameters
----------
name : str
The name of the builtin dataset.
root : str, optional
The root directory of the dataset. Default ot ``datasets``.
"""
_base_url = "https://data.dgl.ai/dataset/graphbolt/"
def __init__(self, name: str, root: str = "datasets") -> OnDiskDataset:
dataset_dir = os.path.join(root, name)
if not os.path.exists(dataset_dir):
url = self._base_url + name + ".zip"
os.makedirs(root, exist_ok=True)
zip_file_path = os.path.join(root, name + ".zip")
download(url, path=zip_file_path)
extract_archive(zip_file_path, root, overwrite=True)
os.remove(zip_file_path)
super().__init__(dataset_dir)
......@@ -1694,3 +1694,22 @@ def test_OnDiskDataset_load_tasks():
original_train_set = None
modify_train_set = None
dataset = None
def test_BuiltinDataset():
"""Test BuiltinDataset."""
with tempfile.TemporaryDirectory() as test_dir:
# Case 1: download from DGL S3 storage.
dataset_name = "test-only"
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
assert dataset.graph is not None
assert dataset.feature is not None
assert dataset.tasks is not None
assert dataset.dataset_name == dataset_name
# Case 2: dataset is already downloaded.
dataset = gb.BuiltinDataset(name=dataset_name, root=test_dir).load()
assert dataset.graph is not None
assert dataset.feature is not None
assert dataset.tasks is not None
assert dataset.dataset_name == dataset_name
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment