Unverified Commit 129e75f3 authored by keli-wen, committed by GitHub

[Graphbolt] Update `preprocess` to Accept Dataset Path Instead of YAML File Path (#6091)

parent bcb5be4a
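For orientation, a minimal usage sketch of the updated entry point. The directory name `./example_dataset` is hypothetical; the point of the change is that the argument is now the dataset directory containing `metadata.yaml`, not the path of the YAML file itself, and the function returns the path of the preprocessed config.

from dgl import graphbolt as gb

# Hypothetical dataset directory; it must contain a `metadata.yaml` describing
# the raw graph, features, and train/validation/test sets.
dataset_dir = "./example_dataset"

# Before this change the YAML path was passed directly, e.g.
# preprocess_ondisk_dataset("./example_dataset/metadata.yaml").
# After this change, pass the directory; the preprocessed config is written to
# <dataset_dir>/preprocessed/metadata.yaml and its path is returned.
output_config_path = gb.ondisk_dataset.preprocess_ondisk_dataset(dataset_dir)
print(output_config_path)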
@@ -4,7 +4,6 @@ import os
 import shutil
 from copy import deepcopy
-from pathlib import Path
 from typing import List

 import pandas as pd
@@ -29,46 +28,80 @@ from .torch_based_feature_store import TorchBasedFeatureStore
 __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]


-def preprocess_ondisk_dataset(input_config_path: str) -> str:
+def _copy_or_convert_data(
+    input_path,
+    output_path,
+    input_format,
+    output_format="numpy",
+    in_memory=True,
+):
+    """Copy or convert the data from input_path to output_path."""
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    if input_format == "numpy":
+        # If the original format is numpy, just copy the file.
+        shutil.copyfile(input_path, output_path)
+    else:
+        # If the original format is not numpy, convert it to numpy.
+        data = read_data(input_path, input_format, in_memory)
+        save_data(data, output_path, output_format)
+
+
+def preprocess_ondisk_dataset(dataset_dir: str) -> str:
     """Preprocess the on-disk dataset. Parse the input config file,
     load the data, and save the data in the format that GraphBolt supports.

     Parameters
     ----------
-    input_config_path : str
-        The path to the input config file.
+    dataset_dir : str
+        The path to the dataset directory.

     Returns
     -------
     output_config_path : str
         The path to the output config file.
     """
-    # 0. Load the input_config.
-    with open(input_config_path, "r") as f:
-        input_config = yaml.safe_load(f)
-
-    # If the input config does not contain the "graph" field, then we
-    # assume that the input config is already preprocessed.
-    if "graph" not in input_config:
-        print("The input config is already preprocessed.")
-        return input_config_path
+    # Check if the dataset path is valid.
+    if not os.path.exists(dataset_dir):
+        raise RuntimeError(f"Invalid dataset path: {dataset_dir}")
+
+    # Check if the dataset_dir is a directory.
+    if not os.path.isdir(dataset_dir):
+        raise RuntimeError(
+            f"The dataset must be a directory. But got {dataset_dir}"
+        )
+
+    # 0. Check if the dataset is already preprocessed.
+    if os.path.exists(os.path.join(dataset_dir, "preprocessed/metadata.yaml")):
+        print("The dataset is already preprocessed.")
+        return os.path.join(dataset_dir, "preprocessed/metadata.yaml")

     print("Start to preprocess the on-disk dataset.")
-    # Infer the dataset path from the input config path.
-    dataset_path = Path(os.path.dirname(input_config_path))
-    processed_dir_prefix = Path("preprocessed")
+    processed_dir_prefix = os.path.join(dataset_dir, "preprocessed")

-    # 1. Make `processed_dir_prefix` directory if it does not exist.
-    os.makedirs(dataset_path / processed_dir_prefix, exist_ok=True)
+    # Check if the metadata.yaml exists.
+    metadata_file_path = os.path.join(dataset_dir, "metadata.yaml")
+    if not os.path.exists(metadata_file_path):
+        raise RuntimeError("metadata.yaml does not exist.")
+
+    # Read the input config.
+    with open(metadata_file_path, "r") as f:
+        input_config = yaml.safe_load(f)
+
+    # 1. Make `processed_dir_abs` directory if it does not exist.
+    os.makedirs(processed_dir_prefix, exist_ok=True)
     output_config = deepcopy(input_config)

     # 2. Load the edge data and create a DGLGraph.
+    if "graph" not in input_config:
+        raise RuntimeError("Invalid config: does not contain graph field.")
     is_homogeneous = "type" not in input_config["graph"]["nodes"][0]
     if is_homogeneous:
         # Homogeneous graph.
         num_nodes = input_config["graph"]["nodes"][0]["num"]
         edge_data = pd.read_csv(
-            dataset_path / input_config["graph"]["edges"][0]["path"],
+            os.path.join(
+                dataset_dir, input_config["graph"]["edges"][0]["path"]
+            ),
             names=["src", "dst"],
         )
         src, dst = edge_data["src"].to_numpy(), edge_data["dst"].to_numpy()
@@ -84,7 +117,8 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
         data_dict = {}
         for edge_info in input_config["graph"]["edges"]:
             edge_data = pd.read_csv(
-                dataset_path / edge_info["path"], names=["src", "dst"]
+                os.path.join(dataset_dir, edge_info["path"]),
+                names=["src", "dst"],
             )
             src = torch.tensor(edge_data["src"])
             dst = torch.tensor(edge_data["dst"])
@@ -98,14 +132,14 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
         for graph_feature in input_config["graph"]["feature_data"]:
             if graph_feature["domain"] == "node":
                 node_data = read_data(
-                    dataset_path / graph_feature["path"],
+                    os.path.join(dataset_dir, graph_feature["path"]),
                     graph_feature["format"],
                     in_memory=graph_feature["in_memory"],
                 )
                 g.ndata[graph_feature["name"]] = node_data
             if graph_feature["domain"] == "edge":
                 edge_data = read_data(
-                    dataset_path / graph_feature["path"],
+                    os.path.join(dataset_dir, graph_feature["path"]),
                     graph_feature["format"],
                     in_memory=graph_feature["in_memory"],
                 )
@@ -117,13 +151,12 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
     # 5. Save the CSCSamplingGraph and modify the output_config.
     output_config["graph_topology"] = {}
     output_config["graph_topology"]["type"] = "CSCSamplingGraph"
-    output_config["graph_topology"]["path"] = str(
-        processed_dir_prefix / "csc_sampling_graph.tar"
+    output_config["graph_topology"]["path"] = os.path.join(
+        processed_dir_prefix, "csc_sampling_graph.tar"
     )
     save_csc_sampling_graph(
-        csc_sampling_graph,
-        str(dataset_path / output_config["graph_topology"]["path"]),
+        csc_sampling_graph, output_config["graph_topology"]["path"]
     )
     del output_config["graph"]
@@ -134,32 +167,16 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
         ):
             # Always save the feature in numpy format.
             out_feature["format"] = "numpy"
-            out_feature["path"] = str(
-                processed_dir_prefix / feature["path"].replace("pt", "npy")
+            out_feature["path"] = os.path.join(
+                processed_dir_prefix, feature["path"].replace("pt", "npy")
             )
-            if feature["format"] == "numpy":
-                # If the original format is numpy, just copy the file.
-                os.makedirs(
-                    dataset_path / os.path.dirname(out_feature["path"]),
-                    exist_ok=True,
-                )
-                shutil.copyfile(
-                    dataset_path / feature["path"],
-                    dataset_path / out_feature["path"],
-                )
-            else:
-                # If the original format is not numpy, convert it to numpy.
-                data = read_data(
-                    dataset_path / feature["path"],
-                    feature["format"],
-                    in_memory=feature["in_memory"],
-                )
-                save_data(
-                    data,
-                    dataset_path / out_feature["path"],
-                    out_feature["format"],
-                )
+            _copy_or_convert_data(
+                os.path.join(dataset_dir, feature["path"]),
+                out_feature["path"],
+                feature["format"],
+                out_feature["format"],
+                feature["in_memory"],
+            )

     # 7. Save the train/val/test split according to the output_config.
     for set_name in ["train_set", "validation_set", "test_set"]:
@@ -173,38 +190,25 @@ def preprocess_ondisk_dataset(input_config_path: str) -> str:
                 ):
                     # Always save the feature in numpy format.
                     output_data["format"] = "numpy"
-                    output_data["path"] = str(
-                        processed_dir_prefix
-                        / input_data["path"].replace("pt", "npy")
+                    output_data["path"] = os.path.join(
+                        processed_dir_prefix,
+                        input_data["path"].replace("pt", "npy"),
                    )
-                    if input_data["format"] == "numpy":
-                        # If the original format is numpy, just copy the file.
-                        os.makedirs(
-                            dataset_path / os.path.dirname(output_data["path"]),
-                            exist_ok=True,
-                        )
-                        shutil.copy(
-                            dataset_path / input_data["path"],
-                            dataset_path / output_data["path"],
-                        )
-                    else:
-                        # If the original format is not numpy, convert it to numpy.
-                        input_set = read_data(
-                            dataset_path / input_data["path"],
-                            input_data["format"],
-                        )
-                        save_data(
-                            input_set,
-                            dataset_path / output_data["path"],
-                            output_set_per_type["format"],
-                        )
+                    _copy_or_convert_data(
+                        os.path.join(dataset_dir, input_data["path"]),
+                        output_data["path"],
+                        input_data["format"],
+                        output_data["format"],
+                    )

     # 8. Save the output_config.
-    output_config_path = dataset_path / "output_config.yaml"
+    output_config_path = os.path.join(dataset_dir, "preprocessed/metadata.yaml")
     with open(output_config_path, "w") as f:
         yaml.dump(output_config, f)
     print("Finish preprocessing the on-disk dataset.")
-    return str(output_config_path)
+
+    # 9. Return the absolute path of the preprocessing yaml file.
+    return output_config_path


 class OnDiskDataset(Dataset):
     ...
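The test changes below assume the directory layout that the new code establishes: the raw config sits at the dataset root as `metadata.yaml`, while `gb.OnDiskDataset` now reads `preprocessed/metadata.yaml` beneath the dataset directory. A small illustrative sketch of setting up that layout by hand (the `dataset_name: example` content is hypothetical):

import os
import tempfile

with tempfile.TemporaryDirectory() as dataset_dir:
    # Raw config at the dataset root, consumed by preprocess_ondisk_dataset.
    with open(os.path.join(dataset_dir, "metadata.yaml"), "w") as f:
        f.write("dataset_name: example\n")

    # Preprocessed config under `preprocessed/`, which is what
    # gb.OnDiskDataset(dataset_dir) loads after this change.
    os.makedirs(os.path.join(dataset_dir, "preprocessed"), exist_ok=True)
    with open(
        os.path.join(dataset_dir, "preprocessed", "metadata.yaml"), "w"
    ) as f:
        f.write("dataset_name: example\n")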
 import os
+import re
 import tempfile

 import gb_test_utils as gbt
@@ -16,7 +17,8 @@ from dgl import graphbolt as gb
 def test_OnDiskDataset_TVTSet_exceptions():
     """Test excpetions thrown when parsing TVTSet."""
     with tempfile.TemporaryDirectory() as test_dir:
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")

         # Case 1: ``format`` is invalid.
         yaml_content = """
@@ -26,13 +28,14 @@ def test_OnDiskDataset_TVTSet_exceptions():
               - format: torch_invalid
                 path: set/paper-train.pt
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
         with pytest.raises(pydantic.ValidationError):
-            _ = gb.OnDiskDataset(yaml_file)
+            _ = gb.OnDiskDataset(test_dir)

-        # Case 2: ``type`` is not specified while multiple TVT sets are specified.
+        # Case 2: ``type`` is not specified while multiple TVT sets are
+        # specified.
         yaml_content = """
             train_set:
               - type: null
@@ -50,7 +53,7 @@ def test_OnDiskDataset_TVTSet_exceptions():
             AssertionError,
             match=r"Only one TVT set is allowed if type is not specified.",
         ):
-            _ = gb.OnDiskDataset(yaml_file)
+            _ = gb.OnDiskDataset(test_dir)


 def test_OnDiskDataset_TVTSet_ItemSet_id_label():
@@ -109,11 +112,12 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
                 in_memory: true
                 path: {test_labels_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify train set.
         train_set = dataset.train_set
@@ -151,11 +155,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
               - format: numpy
                 path: {train_ids_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)
         assert dataset.train_set is not None
         assert dataset.validation_set is None
         assert dataset.test_set is None
@@ -232,11 +236,12 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
                 in_memory: true
                 path: {test_labels_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify train set.
         train_set = dataset.train_set
@@ -344,11 +349,12 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_negs():
                 in_memory: true
                 path: {test_neg_dst_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify train set.
         train_set = dataset.train_set
@@ -434,11 +440,12 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
               - format: numpy
                 path: {test_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify train set.
         train_set = dataset.train_set
@@ -536,11 +543,12 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pair_label():
               - format: numpy
                 path: {test_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify train set.
         train_set = dataset.train_set
@@ -636,11 +644,12 @@ def test_OnDiskDataset_Feature_heterograph():
                 in_memory: true
                 path: {edge_data_label_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify feature data storage.
         feature_data = dataset.feature
@@ -714,11 +723,12 @@ def test_OnDiskDataset_Feature_homograph():
                 in_memory: true
                 path: {edge_data_label_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)

         # Verify feature data storage.
         feature_data = dataset.feature
@@ -757,7 +767,8 @@ def test_OnDiskDataset_Graph_Exceptions():
             type: CSRSamplingGraph
             path: /path/to/graph
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
@@ -765,7 +776,7 @@ def test_OnDiskDataset_Graph_Exceptions():
             pydantic.ValidationError,
             match="1 validation error for OnDiskMetaData",
         ):
-            _ = gb.OnDiskDataset(yaml_file)
+            _ = gb.OnDiskDataset(test_dir)


 def test_OnDiskDataset_Graph_homogeneous():
@@ -782,11 +793,12 @@ def test_OnDiskDataset_Graph_homogeneous():
             type: CSCSamplingGraph
             path: {graph_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)
         graph2 = dataset.graph

         assert graph.num_nodes == graph2.num_nodes
@@ -824,11 +836,12 @@ def test_OnDiskDataset_Graph_heterogeneous():
             type: CSCSamplingGraph
             path: {graph_path}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)
         graph2 = dataset.graph

         assert graph.num_nodes == graph2.num_nodes
@@ -854,11 +867,12 @@ def test_OnDiskDataset_Metadata():
             num_classes: {num_classes}
             num_labels: {num_labels}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)

-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)
         assert dataset.dataset_name == dataset_name
         assert dataset.num_classes == num_classes
         assert dataset.num_labels == num_labels
@@ -867,11 +881,11 @@ def test_OnDiskDataset_Metadata():
         yaml_content = f"""
             dataset_name: {dataset_name}
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
-        dataset = gb.OnDiskDataset(yaml_file)
+        dataset = gb.OnDiskDataset(test_dir)
         assert dataset.dataset_name == dataset_name
         assert dataset.num_classes is None
         assert dataset.num_labels is None
@@ -969,10 +983,10 @@ def test_OnDiskDataset_preprocess_homogeneous():
               - format: numpy
                 path: set/test.npy
         """
-        yaml_file = os.path.join(test_dir, "test.yaml")
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
         with open(yaml_file, "w") as f:
             f.write(yaml_content)
-        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(yaml_file)
+        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(test_dir)

         with open(output_file, "rb") as f:
             processed_dataset = yaml.load(f, Loader=yaml.Loader)
@@ -996,3 +1010,46 @@ def test_OnDiskDataset_preprocess_homogeneous():
             torch.tensor([fanout]),
         )
         assert len(list(subgraph.node_pairs.values())[0][0]) <= num_samples
+
+
+def test_OnDiskDataset_preprocess_path():
+    """Test if the preprocess function can catch the path error."""
+    with tempfile.TemporaryDirectory() as test_dir:
+        # All metadata fields are specified.
+        dataset_name = "graphbolt_test"
+        num_classes = 10
+        num_labels = 9
+
+        yaml_content = f"""
+            dataset_name: {dataset_name}
+            num_classes: {num_classes}
+            num_labels: {num_labels}
+        """
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
+        with open(yaml_file, "w") as f:
+            f.write(yaml_content)
+
+        # Case1. Test the passed in is the yaml file path.
+        with pytest.raises(
+            RuntimeError,
+            match="The dataset must be a directory. "
+            rf"But got {re.escape(yaml_file)}",
+        ):
+            _ = gb.OnDiskDataset(yaml_file)
+
+        # Case2. Test the passed in is a fake directory.
+        fake_dir = os.path.join(test_dir, "fake_dir")
+        with pytest.raises(
+            RuntimeError,
+            match=rf"Invalid dataset path: {re.escape(fake_dir)}",
+        ):
+            _ = gb.OnDiskDataset(fake_dir)
+
+        # Case3. Test the passed in is the dataset directory.
+        # But the metadata.yaml is not in the directory.
+        os.makedirs(os.path.join(test_dir, "fake_dir"), exist_ok=True)
+        with pytest.raises(
+            RuntimeError,
+            match=r"metadata.yaml does not exist.",
+        ):
+            _ = gb.OnDiskDataset(fake_dir)