Unverified Commit d90954b1 authored by keli-wen, committed by GitHub

[Graphbolt] Add the `preprocess_ondisk_dataset` function. (#5991)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 2746aac3
"""GraphBolt OnDiskDataset.""" """GraphBolt OnDiskDataset."""
import os
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
import pandas as pd
import torch
import yaml
import dgl
from ..dataset import Dataset
from ..itemset import ItemSet, ItemSetDict
from ..utils import read_data, save_data, tensor_to_tuple
from .csc_sampling_graph import (
    CSCSamplingGraph,
    from_dglgraph,
    load_csc_sampling_graph,
    save_csc_sampling_graph,
)
from .ondisk_metadata import OnDiskGraphTopology, OnDiskMetaData, OnDiskTVTSet
from .torch_based_feature_store import (
    load_feature_stores,
@@ -16,13 +32,184 @@ from .torch_based_feature_store import (

__all__ = ["OnDiskDataset", "preprocess_ondisk_dataset"]
def preprocess_ondisk_dataset(input_config_path: str) -> str:
    """Preprocess the on-disk dataset. Parse the input config file,
    load the data, and save the data in the format that GraphBolt supports.

    Parameters
    ----------
    input_config_path : str
        The path to the input config file.

    Returns
    -------
    output_config_path : str
        The path to the output config file.
    """
    # 0. Load the input_config.
    with open(input_config_path, "r") as f:
        input_config = yaml.safe_load(f)

    # If the input config does not contain the "graph" field, then we
    # assume that the input config is already preprocessed.
    if "graph" not in input_config:
        print("The input config is already preprocessed.")
        return input_config_path

    print("Start to preprocess the on-disk dataset.")

    # Infer the dataset path from the input config path.
    dataset_path = Path(os.path.dirname(input_config_path))
    processed_dir_prefix = Path("preprocessed")

    # 1. Make `processed_dir_prefix` directory if it does not exist.
    os.makedirs(dataset_path / processed_dir_prefix, exist_ok=True)
    output_config = deepcopy(input_config)
    # 2. Load the edge data and create a DGLGraph.
    is_homogeneous = "type" not in input_config["graph"]["nodes"][0]
    if is_homogeneous:
        # Homogeneous graph.
        num_nodes = input_config["graph"]["nodes"][0]["num"]
        edge_data = pd.read_csv(
            dataset_path / input_config["graph"]["edges"][0]["path"],
            names=["src", "dst"],
        )
        src, dst = edge_data["src"].to_numpy(), edge_data["dst"].to_numpy()
        g = dgl.graph((src, dst), num_nodes=num_nodes)
    else:
        # Heterogeneous graph.
        # Construct the num nodes dict.
        num_nodes_dict = {}
        for node_info in input_config["graph"]["nodes"]:
            num_nodes_dict[node_info["type"]] = node_info["num"]
        # Construct the data dict.
        data_dict = {}
        for edge_info in input_config["graph"]["edges"]:
            edge_data = pd.read_csv(
                dataset_path / edge_info["path"], names=["src", "dst"]
            )
            src = torch.tensor(edge_data["src"])
            dst = torch.tensor(edge_data["dst"])
            data_dict[tuple(edge_info["type"].split(":"))] = (src, dst)
        # Construct the heterograph.
        g = dgl.heterograph(data_dict, num_nodes_dict)

    # 3. Load the sampling related node/edge features and add them to
    # the sampling-graph.
    if input_config["graph"].get("feature_data", None):
        for graph_feature in input_config["graph"]["feature_data"]:
            if graph_feature["domain"] == "node":
                node_data = read_data(
                    dataset_path / graph_feature["path"],
                    graph_feature["format"],
                    in_memory=graph_feature["in_memory"],
                )
                g.ndata[graph_feature["name"]] = node_data
            if graph_feature["domain"] == "edge":
                edge_data = read_data(
                    dataset_path / graph_feature["path"],
                    graph_feature["format"],
                    in_memory=graph_feature["in_memory"],
                )
                g.edata[graph_feature["name"]] = edge_data
    # 4. Convert the DGLGraph to a CSCSamplingGraph.
    csc_sampling_graph = from_dglgraph(g)

    # 5. Save the CSCSamplingGraph and modify the output_config.
    output_config["graph_topology"] = {}
    output_config["graph_topology"]["type"] = "CSCSamplingGraph"
    output_config["graph_topology"]["path"] = (
        processed_dir_prefix / "csc_sampling_graph.tar"
    )
    save_csc_sampling_graph(
        csc_sampling_graph,
        dataset_path / output_config["graph_topology"]["path"],
    )
    del output_config["graph"]
    # 6. Load the node/edge features and do necessary conversion.
    if input_config.get("feature_data", None):
        for feature, out_feature in zip(
            input_config["feature_data"], output_config["feature_data"]
        ):
            # Always save the feature in numpy format.
            out_feature["format"] = "numpy"
            out_feature["path"] = processed_dir_prefix / feature[
                "path"
            ].replace("pt", "npy")
            if feature["format"] == "numpy":
                # If the original format is numpy, just copy the file.
                os.makedirs(
                    dataset_path / os.path.dirname(out_feature["path"]),
                    exist_ok=True,
                )
                shutil.copyfile(
                    dataset_path / feature["path"],
                    dataset_path / out_feature["path"],
                )
            else:
                # If the original format is not numpy, convert it to numpy.
                data = read_data(
                    dataset_path / feature["path"],
                    feature["format"],
                    in_memory=feature["in_memory"],
                )
                save_data(
                    data,
                    dataset_path / out_feature["path"],
                    out_feature["format"],
                )
    # 7. Save the train/val/test split according to the output_config.
    for set_name in ["train_sets", "validation_sets", "test_sets"]:
        if set_name not in input_config:
            continue
        for input_set_split, output_set_split in zip(
            input_config[set_name], output_config[set_name]
        ):
            for input_set_per_type, output_set_per_type in zip(
                input_set_split, output_set_split
            ):
                # Always save the set in numpy format.
                output_set_per_type["format"] = "numpy"
                output_set_per_type[
                    "path"
                ] = processed_dir_prefix / input_set_per_type["path"].replace(
                    "pt", "npy"
                )
                if input_set_per_type["format"] == "numpy":
                    # If the original format is numpy, just copy the file.
                    os.makedirs(
                        dataset_path
                        / os.path.dirname(output_set_per_type["path"]),
                        exist_ok=True,
                    )
                    shutil.copy(
                        dataset_path / input_set_per_type["path"],
                        dataset_path / output_set_per_type["path"],
                    )
                else:
                    # If the original format is not numpy, convert it to numpy.
                    input_set = read_data(
                        dataset_path / input_set_per_type["path"],
                        input_set_per_type["format"],
                    )
                    save_data(
                        input_set,
                        dataset_path / output_set_per_type["path"],
                        output_set_per_type["format"],
                    )
    # 8. Save the output_config.
    output_config_path = dataset_path / "output_config.yaml"
    with open(output_config_path, "w") as f:
        yaml.dump(output_config, f)
    print("Finish preprocessing the on-disk dataset.")
    return output_config_path
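
A minimal usage sketch of the new entry point (not part of this commit; the dataset directory name is illustrative):

from dgl import graphbolt as gb

# "my_dataset/metadata.yaml" is a hypothetical raw config that still
# contains a top-level "graph" section, so preprocessing runs and the
# rewritten config is written next to it as output_config.yaml.
output_config_path = gb.ondisk_dataset.preprocess_ondisk_dataset(
    "my_dataset/metadata.yaml"
)
# Calling it again on the returned config is a no-op: the output config
# has no "graph" field, so the function returns the path unchanged.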
class OnDiskDataset(Dataset):
...
"""Utility functions for GraphBolt.""" """Utility functions for GraphBolt."""
import os
import numpy as np import numpy as np
import torch import torch
...@@ -24,6 +26,27 @@ def read_data(path, fmt, in_memory=True): ...@@ -24,6 +26,27 @@ def read_data(path, fmt, in_memory=True):
raise RuntimeError(f"Unsupported format: {fmt}") raise RuntimeError(f"Unsupported format: {fmt}")

def save_data(data, path, fmt):
    """Save data into disk."""
    # Make sure the directory exists.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if fmt not in ["numpy", "torch"]:
        raise RuntimeError(f"Unsupported format: {fmt}")
    # Perform necessary conversion.
    if fmt == "numpy" and isinstance(data, torch.Tensor):
        data = data.cpu().numpy()
    elif fmt == "torch" and isinstance(data, np.ndarray):
        data = torch.from_numpy(data).cpu()
    # Save the data.
    if fmt == "numpy":
        np.save(path, data)
    elif fmt == "torch":
        torch.save(data, path)
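
A small round-trip sketch for the new helper (illustrative only; the path is hypothetical, and it pairs save_data with the existing read_data from this module):

import numpy as np
import torch

tensor = torch.arange(6).reshape(2, 3)
# fmt="numpy" converts the torch.Tensor to a numpy array before np.save;
# save_data also creates the parent directory if it is missing.
save_data(tensor, "/tmp/graphbolt_example/feat.npy", "numpy")
# Load it back with the matching format.
loaded = read_data("/tmp/graphbolt_example/feat.npy", "numpy")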

def tensor_to_tuple(data):
    """Split a torch.Tensor column-wise into a tuple."""
    assert isinstance(data, torch.Tensor), "data must be a torch.Tensor"
...

@@ -4,10 +4,12 @@ import tempfile
import gb_test_utils as gbt
import numpy as np
import pandas as pd
import pydantic
import pytest
import torch
import yaml

from dgl import graphbolt as gb
@@ -747,3 +749,122 @@ def test_OnDiskDataset_Metadata():
    assert dataset.dataset_name == dataset_name
    assert dataset.num_classes is None
    assert dataset.num_labels is None


def test_OnDiskDataset_preprocess_homogeneous():
    """Test preprocess of OnDiskDataset."""
    with tempfile.TemporaryDirectory() as test_dir:
        # All metadata fields are specified.
        dataset_name = "graphbolt_test"
        num_nodes = 4000
        num_edges = 20000
        num_classes = 10
        num_labels = 9

        # Generate random edges.
        nodes = np.repeat(np.arange(num_nodes), 5)
        neighbors = np.random.randint(0, num_nodes, size=(num_edges))
        edges = np.stack([nodes, neighbors], axis=1)
        # Write into edges/edge.csv.
        os.makedirs(os.path.join(test_dir, "edges/"), exist_ok=True)
        edges = pd.DataFrame(edges, columns=["src", "dst"])
        edges.to_csv(
            os.path.join(test_dir, "edges/edge.csv"),
            index=False,
            header=False,
        )
        # Generate random graph edge-feats.
        edge_feats = np.random.rand(num_edges, 5)
        os.makedirs(os.path.join(test_dir, "data/"), exist_ok=True)
        np.save(os.path.join(test_dir, "data/edge-feat.npy"), edge_feats)

        # Generate random node-feats.
        node_feats = np.random.rand(num_nodes, 10)
        np.save(os.path.join(test_dir, "data/node-feat.npy"), node_feats)

        # Generate train/test/valid set.
        os.makedirs(os.path.join(test_dir, "set/"), exist_ok=True)
        train_pairs = (np.arange(1000), np.arange(1000, 2000))
        train_labels = np.random.randint(0, 10, size=1000)
        train_data = np.vstack([train_pairs, train_labels]).T
        train_path = os.path.join(test_dir, "set/train.npy")
        np.save(train_path, train_data)

        validation_pairs = (np.arange(1000, 2000), np.arange(2000, 3000))
        validation_labels = np.random.randint(0, 10, size=1000)
        validation_data = np.vstack([validation_pairs, validation_labels]).T
        validation_path = os.path.join(test_dir, "set/validation.npy")
        np.save(validation_path, validation_data)

        test_pairs = (np.arange(2000, 3000), np.arange(3000, 4000))
        test_labels = np.random.randint(0, 10, size=1000)
        test_data = np.vstack([test_pairs, test_labels]).T
        test_path = os.path.join(test_dir, "set/test.npy")
        np.save(test_path, test_data)
yaml_content = f"""
dataset_name: {dataset_name}
num_classes: {num_classes}
num_labels: {num_labels}
graph: # graph structure and required attributes.
nodes:
- num: {num_nodes}
edges:
- format: csv
path: edges/edge.csv
feature_data:
- domain: edge
type: null
name: feat
format: numpy
in_memory: true
path: data/edge-feat.npy
feature_data:
- domain: node
type: null
name: feat
format: numpy
in_memory: false
path: data/node-feat.npy
train_sets:
- - type_name: null
# shape: (num_trains, 3), 3 for (src, dst, label).
format: numpy
path: set/train.npy
validation_sets:
- - type_name: null
format: numpy
path: set/validation.npy
test_sets:
- - type_name: null
format: numpy
path: set/test.npy
"""
        yaml_file = os.path.join(test_dir, "test.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(yaml_file)
        with open(output_file, "rb") as f:
            processed_dataset = yaml.safe_load(f)

        assert processed_dataset["dataset_name"] == dataset_name
        assert processed_dataset["num_classes"] == num_classes
        assert processed_dataset["num_labels"] == num_labels
        assert "graph" not in processed_dataset
        assert "graph_topology" in processed_dataset

        csc_sampling_graph = gb.csc_sampling_graph.load_csc_sampling_graph(
            os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
        )
        assert csc_sampling_graph.num_nodes == num_nodes
        assert csc_sampling_graph.num_edges == num_edges

        num_samples = 100
        fanout = 1
        subgraph = csc_sampling_graph.sample_neighbors(
            torch.arange(num_samples),
            torch.tensor([fanout]),
        )
        assert len(list(subgraph.node_pairs.values())[0][0]) <= num_samples
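
For orientation, the output_config.yaml produced for the test above should look roughly like this hand-written sketch (derived from steps 5-7 of the function; key order and exact serialization are up to yaml.dump):

dataset_name: graphbolt_test
num_classes: 10
num_labels: 9
graph_topology:
  type: CSCSamplingGraph
  path: preprocessed/csc_sampling_graph.tar
feature_data:
  - domain: node
    type: null
    name: feat
    format: numpy
    in_memory: false
    path: preprocessed/data/node-feat.npy
train_sets:
  - - type_name: null
      format: numpy
      path: preprocessed/set/train.npy
validation_sets:
  - - type_name: null
      format: numpy
      path: preprocessed/set/validation.npy
test_sets:
  - - type_name: null
      format: numpy
      path: preprocessed/set/test.npy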