Unverified Commit a4ccc3e9 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Graphbolt] Add a load option `original_edge_id` for `preprocess` and `from_dglgraph` (#6438)

parent 63bc44da
...@@ -819,7 +819,11 @@ def save_csc_sampling_graph(graph, filename): ...@@ -819,7 +819,11 @@ def save_csc_sampling_graph(graph, filename):
print(f"CSCSamplingGraph has been saved to {filename}.") print(f"CSCSamplingGraph has been saved to {filename}.")
def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph: def from_dglgraph(
g: DGLGraph,
is_homogeneous: bool = False,
include_original_edge_id: bool = False,
) -> CSCSamplingGraph:
"""Convert a DGLGraph to CSCSamplingGraph.""" """Convert a DGLGraph to CSCSamplingGraph."""
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True) homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
# Initialize metadata. # Initialize metadata.
...@@ -838,8 +842,10 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph: ...@@ -838,8 +842,10 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Assign edge type according to the order of CSC matrix. # Assign edge type according to the order of CSC matrix.
type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids] type_per_edge = None if is_homogeneous else homo_g.edata[ETYPE][edge_ids]
edge_attributes = {}
if include_original_edge_id:
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes = {ORIGINAL_EDGE_ID: homo_g.edata[EID][edge_ids]} edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
return CSCSamplingGraph( return CSCSamplingGraph(
torch.ops.graphbolt.from_csc( torch.ops.graphbolt.from_csc(
......
...@@ -52,7 +52,9 @@ def _copy_or_convert_data( ...@@ -52,7 +52,9 @@ def _copy_or_convert_data(
save_data(data, output_path, output_format) save_data(data, output_path, output_format)
def preprocess_ondisk_dataset(dataset_dir: str) -> str: def preprocess_ondisk_dataset(
dataset_dir: str, include_original_edge_id: bool = False
) -> str:
"""Preprocess the on-disk dataset. Parse the input config file, """Preprocess the on-disk dataset. Parse the input config file,
load the data, and save the data in the format that GraphBolt supports. load the data, and save the data in the format that GraphBolt supports.
...@@ -153,7 +155,9 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str: ...@@ -153,7 +155,9 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
g.edata[graph_feature["name"]] = edge_data g.edata[graph_feature["name"]] = edge_data
# 4. Convert the DGLGraph to a CSCSamplingGraph. # 4. Convert the DGLGraph to a CSCSamplingGraph.
csc_sampling_graph = from_dglgraph(g, is_homogeneous) csc_sampling_graph = from_dglgraph(
g, is_homogeneous, include_original_edge_id
)
# 5. Save the CSCSamplingGraph and modify the output_config. # 5. Save the CSCSamplingGraph and modify the output_config.
output_config["graph_topology"] = {} output_config["graph_topology"] = {}
...@@ -352,11 +356,13 @@ class OnDiskDataset(Dataset): ...@@ -352,11 +356,13 @@ class OnDiskDataset(Dataset):
The YAML file path. The YAML file path.
""" """
def __init__(self, path: str) -> None: def __init__(
self, path: str, include_original_edge_id: bool = False
) -> None:
# Always call the preprocess function first. If already preprocessed, # Always call the preprocess function first. If already preprocessed,
# the function will return the original path directly. # the function will return the original path directly.
self._dataset_dir = path self._dataset_dir = path
yaml_path = preprocess_ondisk_dataset(path) yaml_path = preprocess_ondisk_dataset(path, include_original_edge_id)
with open(yaml_path) as f: with open(yaml_path) as f:
self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader) self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
......
...@@ -1294,8 +1294,20 @@ def test_multiprocessing_with_shared_memory(): ...@@ -1294,8 +1294,20 @@ def test_multiprocessing_with_shared_memory():
) )
def test_from_dglgraph_homogeneous(): def test_from_dglgraph_homogeneous():
dgl_g = dgl.rand_graph(1000, 10 * 1000) dgl_g = dgl.rand_graph(1000, 10 * 1000)
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=True)
# Check if the original edge id exist in edge attributes when the
# original_edge_id is set to False.
gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=False, include_original_edge_id=False
)
assert (
gb_g.edge_attributes is None
or gb.ORIGINAL_EDGE_ID not in gb_g.edge_attributes
)
gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=True, include_original_edge_id=True
)
# Get the COO representation of the CSCSamplingGraph. # Get the COO representation of the CSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1] num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
rows = gb_g.indices rows = gb_g.indices
...@@ -1335,7 +1347,19 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1335,7 +1347,19 @@ def test_from_dglgraph_heterogeneous():
), ),
} }
) )
gb_g = gb.from_dglgraph(dgl_g, is_homogeneous=False) # Check if the original edge id exist in edge attributes when the
# original_edge_id is set to False.
gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=False, include_original_edge_id=False
)
assert (
gb_g.edge_attributes is None
or gb.ORIGINAL_EDGE_ID not in gb_g.edge_attributes
)
gb_g = gb.from_dglgraph(
dgl_g, is_homogeneous=False, include_original_edge_id=True
)
# `reverse_node_id` is used to map the node id in CSCSamplingGraph to the # `reverse_node_id` is used to map the node id in CSCSamplingGraph to the
# node id in Hetero-DGLGraph. # node id in Hetero-DGLGraph.
......
...@@ -1072,7 +1072,9 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1072,7 +1072,9 @@ def test_OnDiskDataset_preprocess_homogeneous():
yaml_file = os.path.join(test_dir, "metadata.yaml") yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f: with open(yaml_file, "w") as f:
f.write(yaml_content) f.write(yaml_content)
output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(test_dir) output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(
test_dir, include_original_edge_id=False
)
with open(output_file, "rb") as f: with open(output_file, "rb") as f:
processed_dataset = yaml.load(f, Loader=yaml.Loader) processed_dataset = yaml.load(f, Loader=yaml.Loader)
...@@ -1087,6 +1089,10 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1087,6 +1089,10 @@ def test_OnDiskDataset_preprocess_homogeneous():
) )
assert csc_sampling_graph.total_num_nodes == num_nodes assert csc_sampling_graph.total_num_nodes == num_nodes
assert csc_sampling_graph.total_num_edges == num_edges assert csc_sampling_graph.total_num_edges == num_edges
assert (
csc_sampling_graph.edge_attributes is None
or gb.ORIGINAL_EDGE_ID not in csc_sampling_graph.edge_attributes
)
num_samples = 100 num_samples = 100
fanout = 1 fanout = 1
...@@ -1096,6 +1102,39 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1096,6 +1102,39 @@ def test_OnDiskDataset_preprocess_homogeneous():
) )
assert len(subgraph.node_pairs[0]) <= num_samples assert len(subgraph.node_pairs[0]) <= num_samples
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
dataset_name = "graphbolt_test"
num_nodes = 4000
num_edges = 20000
num_classes = 10
# Generate random graph.
yaml_content = gbt.random_homo_graphbolt_graph(
test_dir,
dataset_name,
num_nodes,
num_edges,
num_classes,
)
yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
# Test do not generate original_edge_id.
output_file = gb.ondisk_dataset.preprocess_ondisk_dataset(
test_dir, include_original_edge_id=False
)
with open(output_file, "rb") as f:
processed_dataset = yaml.load(f, Loader=yaml.Loader)
csc_sampling_graph = gb.csc_sampling_graph.load_csc_sampling_graph(
os.path.join(test_dir, processed_dataset["graph_topology"]["path"])
)
assert (
csc_sampling_graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID not in csc_sampling_graph.edge_attributes
)
csc_sampling_graph = None
def test_OnDiskDataset_preprocess_path(): def test_OnDiskDataset_preprocess_path():
"""Test if the preprocess function can catch the path error.""" """Test if the preprocess function can catch the path error."""
...@@ -1577,9 +1616,14 @@ def test_OnDiskDataset_load_graph(): ...@@ -1577,9 +1616,14 @@ def test_OnDiskDataset_load_graph():
with open(yaml_file, "w") as f: with open(yaml_file, "w") as f:
f.write(yaml_content) f.write(yaml_content)
# Check if the CSCSamplingGraph.edge_attributes loaded. # Check the different original_edge_id option to load edge_attributes.
dataset = gb.OnDiskDataset(test_dir).load() dataset = gb.OnDiskDataset(
assert dataset.graph.edge_attributes is not None test_dir, include_original_edge_id=True
).load()
assert (
dataset.graph.edge_attributes is not None
and gb.ORIGINAL_EDGE_ID in dataset.graph.edge_attributes
)
# Case1. Test modify the `type` field. # Case1. Test modify the `type` field.
dataset = gb.OnDiskDataset(test_dir) dataset = gb.OnDiskDataset(test_dir)
...@@ -1620,6 +1664,35 @@ def test_OnDiskDataset_load_graph(): ...@@ -1620,6 +1664,35 @@ def test_OnDiskDataset_load_graph():
modify_graph = None modify_graph = None
dataset = None dataset = None
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
dataset_name = "graphbolt_test"
num_nodes = 4000
num_edges = 20000
num_classes = 10
# Generate random graph.
yaml_content = gbt.random_homo_graphbolt_graph(
test_dir,
dataset_name,
num_nodes,
num_edges,
num_classes,
)
yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
# Test do not generate original_edge_id.
dataset = gb.OnDiskDataset(
test_dir, include_original_edge_id=False
).load()
assert (
dataset.graph.edge_attributes is None
or gb.ORIGINAL_EDGE_ID not in dataset.graph.edge_attributes
)
dataset = None
def test_OnDiskDataset_load_tasks(): def test_OnDiskDataset_load_tasks():
"""Test preprocess of OnDiskDataset.""" """Test preprocess of OnDiskDataset."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment