Unverified Commit 04752e9f authored by yxy235, committed by GitHub

[GraphBolt] Add `force_reload` to OnDiskDataset. (#6930)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 1193e2e8
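
For orientation, a minimal usage sketch of the new flag (the directory name "./example_dataset" is a placeholder and is assumed to already contain a GraphBolt metadata.yaml):

import dgl.graphbolt as gb

# The first call preprocesses the raw dataset into ./example_dataset/preprocessed.
dataset = gb.OnDiskDataset("./example_dataset").load()

# Subsequent calls reuse the preprocessed copy. If the raw metadata.yaml has
# changed on disk, force_preprocess=True discards the stale copy and
# preprocesses the dataset again.
dataset = gb.OnDiskDataset("./example_dataset", force_preprocess=True).load()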
"""GraphBolt OnDiskDataset.""" """GraphBolt OnDiskDataset."""
import os import os
import shutil
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Union from typing import Dict, List, Union
...@@ -34,7 +35,9 @@ __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"] ...@@ -34,7 +35,9 @@ __all__ = ["OnDiskDataset", "preprocess_ondisk_dataset", "BuiltinDataset"]
def preprocess_ondisk_dataset( def preprocess_ondisk_dataset(
dataset_dir: str, include_original_edge_id: bool = False dataset_dir: str,
include_original_edge_id: bool = False,
force_preprocess: bool = False,
) -> str: ) -> str:
"""Preprocess the on-disk dataset. Parse the input config file, """Preprocess the on-disk dataset. Parse the input config file,
load the data, and save the data in the format that GraphBolt supports. load the data, and save the data in the format that GraphBolt supports.
@@ -45,6 +48,8 @@ def preprocess_ondisk_dataset(
         The path to the dataset directory.
     include_original_edge_id : bool, optional
         Whether to include the original edge id in the FusedCSCSamplingGraph.
+    force_preprocess: bool, optional
+        Whether to force re-preprocessing of the on-disk dataset.
     Returns
     -------
@@ -62,13 +67,22 @@ def preprocess_ondisk_dataset(
     )
     # 0. Check if the dataset is already preprocessed.
-    preprocess_metadata_path = os.path.join("preprocessed", "metadata.yaml")
+    processed_dir_prefix = "preprocessed"
+    preprocess_metadata_path = os.path.join(
+        processed_dir_prefix, "metadata.yaml"
+    )
     if os.path.exists(os.path.join(dataset_dir, preprocess_metadata_path)):
-        print("The dataset is already preprocessed.")
-        return os.path.join(dataset_dir, preprocess_metadata_path)
+        if force_preprocess:
+            shutil.rmtree(os.path.join(dataset_dir, processed_dir_prefix))
+            print(
+                "The on-disk dataset is being re-preprocessed, so the existing "
+                + "preprocessed dataset has been removed."
+            )
+        else:
+            print("The dataset is already preprocessed.")
+            return os.path.join(dataset_dir, preprocess_metadata_path)
     print("Start to preprocess the on-disk dataset.")
-    processed_dir_prefix = "preprocessed"
     # Check if the metadata.yaml exists.
     metadata_file_path = os.path.join(dataset_dir, "metadata.yaml")
@@ -376,15 +390,22 @@ class OnDiskDataset(Dataset):
         The YAML file path.
     include_original_edge_id: bool, optional
         Whether to include the original edge id in the FusedCSCSamplingGraph.
+    force_preprocess: bool, optional
+        Whether to force re-preprocessing of the on-disk dataset.
     """

     def __init__(
-        self, path: str, include_original_edge_id: bool = False
+        self,
+        path: str,
+        include_original_edge_id: bool = False,
+        force_preprocess: bool = False,
     ) -> None:
         # Always call the preprocess function first. If already preprocessed,
         # the function will return the original path directly.
         self._dataset_dir = path
-        yaml_path = preprocess_ondisk_dataset(path, include_original_edge_id)
+        yaml_path = preprocess_ondisk_dataset(
+            path, include_original_edge_id, force_preprocess
+        )
         with open(yaml_path) as f:
             self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
         self._loaded = False
...
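
The functional entry point exposes the same flag. A short sketch of the resulting contract, using the same placeholder path as above:

import dgl.graphbolt as gb

# Returns the path to preprocessed/metadata.yaml; if that file already exists,
# the call skips preprocessing and returns the path directly.
yaml_path = gb.ondisk_dataset.preprocess_ondisk_dataset("./example_dataset")

# With force_preprocess=True, the existing preprocessed/ directory is removed
# first and the dataset is preprocessed again from the raw files.
yaml_path = gb.ondisk_dataset.preprocess_ondisk_dataset(
    "./example_dataset", force_preprocess=True
)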
@@ -1531,6 +1531,79 @@ def test_OnDiskDataset_preprocess_yaml_content_windows():
     )
+def test_OnDiskDataset_preprocess_force_preprocess(capsys):
+    """Test force preprocess of OnDiskDataset."""
+    with tempfile.TemporaryDirectory() as test_dir:
+        # All metadata fields are specified.
+        dataset_name = "graphbolt_test"
+        num_nodes = 4000
+        num_edges = 20000
+        num_classes = 10
+
+        # Generate random graph.
+        yaml_content = gbt.random_homo_graphbolt_graph(
+            test_dir,
+            dataset_name,
+            num_nodes,
+            num_edges,
+            num_classes,
+        )
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
+        with open(yaml_file, "w") as f:
+            f.write(yaml_content)
+
+        # Preprocess the on-disk dataset for the first time.
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=False
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "link_prediction"
+
+        # Change the YAML data on disk, but do not force re-preprocessing.
+        with open(yaml_file, "r") as f:
+            yaml_data = yaml.safe_load(f)
+        yaml_data["tasks"][0]["name"] = "fake_name"
+        with open(yaml_file, "w") as f:
+            yaml.dump(yaml_data, f)
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=False
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == ["The dataset is already preprocessed.", ""]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "link_prediction"
+
+        # Force re-preprocessing of the on-disk dataset.
+        preprocessed_metadata_path = (
+            gb.ondisk_dataset.preprocess_ondisk_dataset(
+                test_dir, include_original_edge_id=False, force_preprocess=True
+            )
+        )
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "The on-disk dataset is being re-preprocessed, so the existing "
+            + "preprocessed dataset has been removed.",
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        with open(preprocessed_metadata_path, "r") as f:
+            target_yaml_data = yaml.safe_load(f)
+        assert target_yaml_data["tasks"][0]["name"] == "fake_name"
+
+
 @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])
 def test_OnDiskDataset_load_name(edge_fmt):
     """Test preprocess of OnDiskDataset."""
@@ -2182,6 +2255,73 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id, edge_fmt):
     dataset = None
+def test_OnDiskDataset_force_preprocess(capsys):
+    """Test force preprocess of OnDiskDataset."""
+    with tempfile.TemporaryDirectory() as test_dir:
+        # All metadata fields are specified.
+        dataset_name = "graphbolt_test"
+        num_nodes = 4000
+        num_edges = 20000
+        num_classes = 10
+
+        # Generate random graph.
+        yaml_content = gbt.random_homo_graphbolt_graph(
+            test_dir,
+            dataset_name,
+            num_nodes,
+            num_edges,
+            num_classes,
+        )
+        yaml_file = os.path.join(test_dir, "metadata.yaml")
+        with open(yaml_file, "w") as f:
+            f.write(yaml_content)
+
+        # Preprocess the on-disk dataset for the first time.
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=False
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "link_prediction"
+
+        # Change the YAML data on disk, but do not force re-preprocessing.
+        with open(yaml_file, "r") as f:
+            yaml_data = yaml.safe_load(f)
+        yaml_data["tasks"][0]["name"] = "fake_name"
+        with open(yaml_file, "w") as f:
+            yaml.dump(yaml_data, f)
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=False
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == ["The dataset is already preprocessed.", ""]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "link_prediction"
+
+        # Force re-preprocessing of the on-disk dataset.
+        dataset = gb.OnDiskDataset(
+            test_dir, include_original_edge_id=False, force_preprocess=True
+        ).load()
+        captured = capsys.readouterr().out.split("\n")
+        assert captured == [
+            "The on-disk dataset is being re-preprocessed, so the existing "
+            + "preprocessed dataset has been removed.",
+            "Start to preprocess the on-disk dataset.",
+            "Finish preprocessing the on-disk dataset.",
+            "",
+        ]
+        tasks = dataset.tasks
+        assert tasks[0].metadata["name"] == "fake_name"
+
+        tasks = None
+        dataset = None
+
+
 def test_OnDiskTask_repr_homogeneous():
     item_set = gb.ItemSet(
         (torch.arange(0, 5), torch.arange(5, 10)),
...
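
Both tests assert on console output through pytest's built-in capsys fixture. A stripped-down sketch of the pattern, where emit is a hypothetical stand-in for the preprocessing call:

def emit():
    print("Start to preprocess the on-disk dataset.")

def test_emit(capsys):
    emit()
    # readouterr() drains everything captured on stdout/stderr so far, so each
    # stage of a test can assert on exactly the messages it triggered.
    captured = capsys.readouterr().out.split("\n")
    assert captured == ["Start to preprocess the on-disk dataset.", ""]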