Unverified Commit 47d37e91 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] convert TVT from list of list to list (#6080)

parent 12ade95c
"""GraphBolt Dataset."""
from typing import Dict, List
from typing import Dict
from .feature_store import FeatureStore
from .itemset import ItemSet, ItemSetDict
......@@ -32,18 +32,18 @@ class Dataset:
"""
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the training sets."""
def train_set(self) -> ItemSet or ItemSetDict:
"""Return the training set."""
raise NotImplementedError
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the validation sets."""
def validation_set(self) -> ItemSet or ItemSetDict:
"""Return the validation set."""
raise NotImplementedError
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
"""Return the test sets."""
def test_set(self) -> ItemSet or ItemSetDict:
"""Return the test set."""
raise NotImplementedError
@property
......
......@@ -165,45 +165,42 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
)
# 7. Save the train/val/test split according to the output_config.
for set_name in ["train_sets", "validation_sets", "test_sets"]:
for set_name in ["train_set", "validation_set", "test_set"]:
if set_name not in input_config:
continue
for intput_set_split, output_set_split in zip(
for input_set_per_type, output_set_per_type in zip(
input_config[set_name], output_config[set_name]
):
for input_set_per_type, output_set_per_type in zip(
intput_set_split, output_set_split
for input_data, output_data in zip(
input_set_per_type["data"], output_set_per_type["data"]
):
for input_data, output_data in zip(
input_set_per_type["data"], output_set_per_type["data"]
):
# Always save the feature in numpy format.
output_data["format"] = "numpy"
output_data["path"] = str(
processed_dir_prefix
/ input_data["path"].replace("pt", "npy")
# Always save the feature in numpy format.
output_data["format"] = "numpy"
output_data["path"] = str(
processed_dir_prefix
/ input_data["path"].replace("pt", "npy")
)
if input_data["format"] == "numpy":
# If the original format is numpy, just copy the file.
os.makedirs(
dataset_path / os.path.dirname(output_data["path"]),
exist_ok=True,
)
shutil.copy(
dataset_path / input_data["path"],
dataset_path / output_data["path"],
)
else:
# If the original format is not numpy, convert it to numpy.
input_set = read_data(
dataset_path / input_data["path"],
input_data["format"],
)
save_data(
input_set,
dataset_path / output_data["path"],
output_set_per_type["format"],
)
if input_data["format"] == "numpy":
# If the original format is numpy, just copy the file.
os.makedirs(
dataset_path / os.path.dirname(output_data["path"]),
exist_ok=True,
)
shutil.copy(
dataset_path / input_data["path"],
dataset_path / output_data["path"],
)
else:
# If the original format is not numpy, convert it to numpy.
input_set = read_data(
dataset_path / input_data["path"],
input_data["format"],
)
save_data(
input_set,
dataset_path / output_data["path"],
output_set_per_type["format"],
)
# 8. Save the output_config.
output_config_path = dataset_path / "output_config.yaml"
......@@ -245,27 +242,27 @@ class OnDiskDataset(Dataset):
format: numpy
in_memory: false
path: edge_data/author-writes-paper-feat.npy
train_sets:
- - type: paper # could be null for homogeneous graph.
data: # multiple data sources could be specified.
- format: numpy
in_memory: true # If not specified, default to true.
path: set/paper-train-src.npy
- format: numpy
in_memory: false
path: set/paper-train-dst.npy
validation_sets:
- - type: paper
data:
- format: numpy
in_memory: true
path: set/paper-validation.npy
test_sets:
- - type: paper
data:
- format: numpy
in_memory: true
path: set/paper-test.npy
train_set:
- type: paper # could be null for homogeneous graph.
data: # multiple data sources could be specified.
- format: numpy
in_memory: true # If not specified, default to true.
path: set/paper-train-src.npy
- format: numpy
in_memory: false
path: set/paper-train-dst.npy
validation_set:
- type: paper
data:
- format: numpy
in_memory: true
path: set/paper-validation.npy
test_set:
- type: paper
data:
- format: numpy
in_memory: true
path: set/paper-test.npy
Parameters
----------
......@@ -285,24 +282,24 @@ class OnDiskDataset(Dataset):
self._num_labels = self._meta.num_labels
self._graph = self._load_graph(self._meta.graph_topology)
self._feature = load_feature_stores(self._meta.feature_data)
self._train_sets = self._init_tvt_sets(self._meta.train_sets)
self._validation_sets = self._init_tvt_sets(self._meta.validation_sets)
self._test_sets = self._init_tvt_sets(self._meta.test_sets)
self._train_set = self._init_tvt_set(self._meta.train_set)
self._validation_set = self._init_tvt_set(self._meta.validation_set)
self._test_set = self._init_tvt_set(self._meta.test_set)
@property
def train_sets(self) -> List[ItemSet] or List[ItemSetDict]:
def train_set(self) -> ItemSet or ItemSetDict:
"""Return the training set."""
return self._train_sets
return self._train_set
@property
def validation_sets(self) -> List[ItemSet] or List[ItemSetDict]:
def validation_set(self) -> ItemSet or ItemSetDict:
"""Return the validation set."""
return self._validation_sets
return self._validation_set
@property
def test_sets(self) -> List[ItemSet] or List[ItemSetDict]:
def test_set(self) -> ItemSet or ItemSetDict:
"""Return the test set."""
return self._test_sets
return self._test_set
@property
def graph(self) -> object:
......@@ -341,36 +338,31 @@ class OnDiskDataset(Dataset):
f"Graph topology type {graph_topology.type} is not supported."
)
def _init_tvt_sets(
self, tvt_sets: List[List[OnDiskTVTSet]]
) -> List[ItemSet] or List[ItemSetDict]:
"""Initialize the TVT sets."""
if (tvt_sets is None) or (len(tvt_sets) == 0):
return None
ret = []
for tvt_set in tvt_sets:
if (tvt_set is None) or (len(tvt_set) == 0):
ret.append(None)
if tvt_set[0].type is None:
assert (
len(tvt_set) == 1
), "Only one TVT set is allowed if type is not specified."
ret.append(
ItemSet(
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt_set[0].data
)
)
def _init_tvt_set(
self, tvt_set: List[OnDiskTVTSet]
) -> ItemSet or ItemSetDict:
"""Initialize the TVT set."""
ret = None
if (tvt_set is None) or (len(tvt_set) == 0):
return ret
if tvt_set[0].type is None:
assert (
len(tvt_set) == 1
), "Only one TVT set is allowed if type is not specified."
ret = ItemSet(
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt_set[0].data
)
else:
data = {}
for tvt in tvt_set:
data[tvt.type] = ItemSet(
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt.data
)
)
else:
data = {}
for tvt in tvt_set:
data[tvt.type] = ItemSet(
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt.data
)
ret.append(ItemSetDict(data))
)
ret = ItemSetDict(data)
return ret
......@@ -83,6 +83,6 @@ class OnDiskMetaData(pydantic.BaseModel):
num_labels: Optional[int] = None
graph_topology: Optional[OnDiskGraphTopology] = None
feature_data: Optional[List[OnDiskFeatureData]] = []
train_sets: Optional[List[List[OnDiskTVTSet]]] = []
validation_sets: Optional[List[List[OnDiskTVTSet]]] = []
test_sets: Optional[List[List[OnDiskTVTSet]]] = []
train_set: Optional[List[OnDiskTVTSet]] = []
validation_set: Optional[List[OnDiskTVTSet]] = []
test_set: Optional[List[OnDiskTVTSet]] = []
import os
import tempfile
import numpy as np
import pydantic
import pytest
from dgl import graphbolt as gb
......@@ -11,15 +5,15 @@ from dgl import graphbolt as gb
def test_Dataset():
dataset = gb.Dataset()
with pytest.raises(NotImplementedError):
_ = dataset.train_sets()
_ = dataset.train_set
with pytest.raises(NotImplementedError):
_ = dataset.validation_sets()
_ = dataset.validation_set
with pytest.raises(NotImplementedError):
_ = dataset.test_sets()
_ = dataset.test_set
with pytest.raises(NotImplementedError):
_ = dataset.graph()
_ = dataset.graph
with pytest.raises(NotImplementedError):
_ = dataset.feature()
_ = dataset.feature
with pytest.raises(NotImplementedError):
_ = dataset.dataset_name
with pytest.raises(NotImplementedError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment