Unverified Commit 14f396d0 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] change TVT format of OnDiskDataset (#6076)

parent 17f6c4c9
......@@ -15,7 +15,7 @@ import dgl
from ..dataset import Dataset
from ..itemset import ItemSet, ItemSetDict
from ..utils import read_data, save_data, tensor_to_tuple
from ..utils import read_data, save_data
from .csc_sampling_graph import (
CSCSamplingGraph,
......@@ -173,33 +173,35 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
):
for input_set_per_type, output_set_per_type in zip(
intput_set_split, output_set_split
):
for input_data, output_data in zip(
input_set_per_type["data"], output_set_per_type["data"]
):
# Always save the feature in numpy format.
output_set_per_type["format"] = "numpy"
output_set_per_type["path"] = str(
output_data["format"] = "numpy"
output_data["path"] = str(
processed_dir_prefix
/ input_set_per_type["path"].replace("pt", "npy")
/ input_data["path"].replace("pt", "npy")
)
if input_set_per_type["format"] == "numpy":
if input_data["format"] == "numpy":
# If the original format is numpy, just copy the file.
os.makedirs(
dataset_path
/ os.path.dirname(output_set_per_type["path"]),
dataset_path / os.path.dirname(output_data["path"]),
exist_ok=True,
)
shutil.copy(
dataset_path / input_set_per_type["path"],
dataset_path / output_set_per_type["path"],
dataset_path / input_data["path"],
dataset_path / output_data["path"],
)
else:
# If the original format is not numpy, convert it to numpy.
input_set = read_data(
dataset_path / input_set_per_type["path"],
input_set_per_type["format"],
dataset_path / input_data["path"],
input_data["format"],
)
save_data(
input_set,
dataset_path / output_set_per_type["path"],
dataset_path / output_data["path"],
output_set_per_type["format"],
)
......@@ -245,17 +247,23 @@ class OnDiskDataset(Dataset):
path: edge_data/author-writes-paper-feat.npy
train_sets:
- - type: paper # could be null for homogeneous graph.
format: numpy
data: # multiple data sources could be specified.
- format: numpy
in_memory: true # If not specified, default to true.
path: set/paper-train.npy
path: set/paper-train-src.npy
- format: numpy
in_memory: false
path: set/paper-train-dst.npy
validation_sets:
- - type: paper
format: numpy
data:
- format: numpy
in_memory: true
path: set/paper-validation.npy
test_sets:
- - type: paper
format: numpy
data:
- format: numpy
in_memory: true
path: set/paper-test.npy
......@@ -347,16 +355,21 @@ class OnDiskDataset(Dataset):
assert (
len(tvt_set) == 1
), "Only one TVT set is allowed if type is not specified."
data = read_data(
tvt_set[0].path, tvt_set[0].format, tvt_set[0].in_memory
ret.append(
ItemSet(
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt_set[0].data
)
)
)
ret.append(ItemSet(tensor_to_tuple(data)))
else:
data = {}
for tvt in tvt_set:
data[tvt.type] = ItemSet(
tensor_to_tuple(
read_data(tvt.path, tvt.format, tvt.in_memory)
tuple(
read_data(data.path, data.format, data.in_memory)
for data in tvt.data
)
)
ret.append(ItemSetDict(data))
......
......@@ -8,6 +8,7 @@ import pydantic
__all__ = [
"OnDiskFeatureDataFormat",
"OnDiskTVTSetData",
"OnDiskTVTSet",
"OnDiskFeatureDataDomain",
"OnDiskFeatureData",
......@@ -24,15 +25,21 @@ class OnDiskFeatureDataFormat(str, Enum):
NUMPY = "numpy"
class OnDiskTVTSet(pydantic.BaseModel):
"""Train-Validation-Test set."""
class OnDiskTVTSetData(pydantic.BaseModel):
"""Train-Validation-Test set data."""
type: Optional[str] = None
format: OnDiskFeatureDataFormat
in_memory: Optional[bool] = True
path: str
class OnDiskTVTSet(pydantic.BaseModel):
"""Train-Validation-Test set."""
type: Optional[str] = None
data: List[OnDiskTVTSetData]
class OnDiskFeatureDataDomain(str, Enum):
"""Enum of feature data domain."""
......
......@@ -45,9 +45,3 @@ def save_data(data, path, fmt):
np.save(path, data)
elif fmt == "torch":
torch.save(data, path)
def tensor_to_tuple(data):
"""Split a torch.Tensor in column-wise to a tuple."""
assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
return tuple(data.t())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment