Unverified Commit 04752e9f authored by yxy235, committed by GitHub

[GraphBolt] Add `force_reload` to OnDiskDataset. (#6930)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 1193e2e8
"""GraphBolt OnDiskDataset."""
import os
import shutil
from copy import deepcopy
from typing import Dict, List, Union
@@ -34,7 +35,9 @@ __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
 def preprocess_ondisk_dataset(
-    dataset_dir: str, include_original_edge_id: bool = False
+    dataset_dir: str,
+    include_original_edge_id: bool = False,
+    force_preprocess: bool = False,
 ) -> str:
     """Preprocess the on-disk dataset. Parse the input config file,
     load the data, and save the data in the format that GraphBolt supports.
@@ -45,6 +48,8 @@ def preprocess_ondisk_dataset(
         The path to the dataset directory.
     include_original_edge_id : bool, optional
         Whether to include the original edge id in the FusedCSCSamplingGraph.
+    force_preprocess : bool, optional
+        Whether to force re-preprocess the on-disk dataset.
 
     Returns
     -------
@@ -62,13 +67,22 @@ def preprocess_ondisk_dataset(
         )
     # 0. Check if the dataset is already preprocessed.
-    preprocess_metadata_path = os.path.join("preprocessed", "metadata.yaml")
+    processed_dir_prefix = "preprocessed"
+    preprocess_metadata_path = os.path.join(
+        processed_dir_prefix, "metadata.yaml"
+    )
     if os.path.exists(os.path.join(dataset_dir, preprocess_metadata_path)):
-        print("The dataset is already preprocessed.")
-        return os.path.join(dataset_dir, preprocess_metadata_path)
+        if force_preprocess:
+            shutil.rmtree(os.path.join(dataset_dir, processed_dir_prefix))
+            print(
+                "The on-disk dataset is being re-preprocessed, so the "
+                + "existing preprocessed dataset has been removed."
+            )
+        else:
+            print("The dataset is already preprocessed.")
+            return os.path.join(dataset_dir, preprocess_metadata_path)
     print("Start to preprocess the on-disk dataset.")
-    processed_dir_prefix = "preprocessed"
     # Check if the metadata.yaml exists.
     metadata_file_path = os.path.join(dataset_dir, "metadata.yaml")
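A minimal usage sketch of the preprocessing entry point with the new flag (the dataset directory here is hypothetical; `preprocess_ondisk_dataset` is reached through `gb.ondisk_dataset`, exactly as the tests below do):

import dgl.graphbolt as gb

# First call parses metadata.yaml and materializes the "preprocessed"
# directory next to the raw data; repeated calls return early.
yaml_path = gb.ondisk_dataset.preprocess_ondisk_dataset("my_dataset_dir")

# force_preprocess=True removes the existing "preprocessed" directory
# and rebuilds it, so later edits to metadata.yaml take effect.
yaml_path = gb.ondisk_dataset.preprocess_ondisk_dataset(
    "my_dataset_dir", include_original_edge_id=False, force_preprocess=True
)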
@@ -376,15 +390,22 @@ class OnDiskDataset(Dataset):
         The YAML file path.
     include_original_edge_id: bool, optional
         Whether to include the original edge id in the FusedCSCSamplingGraph.
+    force_preprocess: bool, optional
+        Whether to force re-preprocess the on-disk dataset.
     """
 
     def __init__(
-        self, path: str, include_original_edge_id: bool = False
+        self,
+        path: str,
+        include_original_edge_id: bool = False,
+        force_preprocess: bool = False,
     ) -> None:
         # Always call the preprocess function first. If already preprocessed,
         # the function will return the original path directly.
         self._dataset_dir = path
-        yaml_path = preprocess_ondisk_dataset(path, include_original_edge_id)
+        yaml_path = preprocess_ondisk_dataset(
+            path, include_original_edge_id, force_preprocess
+        )
         with open(yaml_path) as f:
             self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
         self._loaded = False
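The same flag through the dataset class, which simply forwards it to `preprocess_ondisk_dataset` as shown above (a sketch with a hypothetical path, not part of the commit):

import dgl.graphbolt as gb

# A plain load would reuse the cached "preprocessed" artifacts even if
# metadata.yaml changed; force_preprocess=True rebuilds them first.
dataset = gb.OnDiskDataset(
    "my_dataset_dir", include_original_edge_id=False, force_preprocess=True
).load()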
@@ -1531,6 +1531,79 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
     )
 
 
+def test_OnDiskDataset_preprocess_force_preprocess(capsys):
+    """Test force preprocess of OnDiskDataset."""
+    with tempfile.TemporaryDirectory() as test_dir:
+        # All metadata fields are specified.
+        dataset_name = "graphbolt_test"
+        num_nodes = 4000
+        num_edges = 20000
+        num_classes = 10
+
+        # Generate random graph.
+        yaml_content = gbt.random_homo_graphbolt_graph(
+            test_dir,
+            dataset_name,
+            num_nodes,
+            num_edges,
+            num_classes,
+        )
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
+        with open(yaml_file, "w") as f:
+            f.write(yaml_content)
+
+        # First preprocess on-disk dataset.
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=False
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "link_prediction"
+
+        # Change yaml_data, but do not force preprocess on-disk dataset.
+        with open(yaml_file, "r") as f:
+            yaml_data = yaml.safe_load(f)
+        yaml_data["tasks"][0]["name"] = "fake_name"
+        with open(yaml_file, "w") as f:
+            yaml.dump(yaml_data, f)
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=False
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == ["The dataset is already preprocessed.", ""]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "link_prediction"
+
+        # Force preprocess on-disk dataset.
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=True
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "The on-disk dataset is being re-preprocessed, so the "
+            + "existing preprocessed dataset has been removed.",
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "fake_name"
+
+
 @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])
 def test_OnDiskDataset_load_name(edge_fmt):
     """Test preprocess of OnDiskDataset."""
@@ -2182,6 +2255,73 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
     dataset = None
 
 
+def test_OnDiskDataset_force_preprocess(capsys):
+    """Test force preprocess of OnDiskDataset."""
+    with tempfile.TemporaryDirectory() as test_dir:
+        # All metadata fields are specified.
+        dataset_name = "graphbolt_test"
+        num_nodes = 4000
+        num_edges = 20000
+        num_classes = 10
+
+        # Generate random graph.
+        yaml_content = gbt.random_homo_graphbolt_graph(
+            test_dir,
+            dataset_name,
+            num_nodes,
+            num_edges,
+            num_classes,
+        )
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
+        with open(yaml_file, "w") as f:
+            f.write(yaml_content)
+
+        # First preprocess on-disk dataset.
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=False
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "link_prediction"
+
+        # Change yaml_data, but do not force preprocess on-disk dataset.
+        with open(yaml_file, "r") as f:
+            yaml_data = yaml.safe_load(f)
+        yaml_data["tasks"][0]["name"] = "fake_name"
+        with open(yaml_file, "w") as f:
+            yaml.dump(yaml_data, f)
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=False
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == ["The dataset is already preprocessed.", ""]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "link_prediction"
+
+        # Force preprocess on-disk dataset.
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=True
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "The on-disk dataset is being re-preprocessed, so the "
+            + "existing preprocessed dataset has been removed.",
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "fake_name"
+
+        tasks = None
+        dataset = None
+
+
 def test_OnDiskTask_repr_homogeneous():
     item_set = gb.ItemSet(
         (torch.arange(0, 5), torch.arange(5, 10)),
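Both new tests can be selected with pytest's keyword filter; a plausible local invocation, run from the directory containing this test module (the repository's test layout is assumed, not shown in this diff):

python -m pytest -k force_preprocess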