Unverified Commit de1eedc6 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add check about whether edge IDs are saved when edge feature is stored. (#6948)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 905321f8
...@@ -219,6 +219,7 @@ def preprocess_ondisk_dataset( ...@@ -219,6 +219,7 @@ def preprocess_ondisk_dataset(
# 7. Load the node/edge features and do necessary conversion. # 7. Load the node/edge features and do necessary conversion.
if input_config.get("feature_data", None): if input_config.get("feature_data", None):
has_edge_feature_data = False
for feature, out_feature in zip( for feature, out_feature in zip(
input_config["feature_data"], output_config["feature_data"] input_config["feature_data"], output_config["feature_data"]
): ):
...@@ -230,6 +231,8 @@ def preprocess_ondisk_dataset( ...@@ -230,6 +231,8 @@ def preprocess_ondisk_dataset(
in_memory = ( in_memory = (
True if "in_memory" not in feature else feature["in_memory"] True if "in_memory" not in feature else feature["in_memory"]
) )
if not has_edge_feature_data and feature["domain"] == "edge":
has_edge_feature_data = True
copy_or_convert_data( copy_or_convert_data(
os.path.join(dataset_dir, feature["path"]), os.path.join(dataset_dir, feature["path"]),
os.path.join(dataset_dir, out_feature["path"]), os.path.join(dataset_dir, out_feature["path"]),
...@@ -238,6 +241,8 @@ def preprocess_ondisk_dataset( ...@@ -238,6 +241,8 @@ def preprocess_ondisk_dataset(
in_memory=in_memory, in_memory=in_memory,
is_feature=True, is_feature=True,
) )
if has_edge_feature_data and not include_original_edge_id:
dgl_warning("Edge feature is stored, but edge IDs are not saved.")
# 8. Save tasks and train/val/test split according to the output_config. # 8. Save tasks and train/val/test split according to the output_config.
if input_config.get("tasks", None): if input_config.get("tasks", None):
......
...@@ -1726,6 +1726,35 @@ def test_OnDiskDataset_preprocess_auto_force_preprocess(capsys): ...@@ -1726,6 +1726,35 @@ def test_OnDiskDataset_preprocess_auto_force_preprocess(capsys):
assert captured == ["The dataset is already preprocessed.", ""] assert captured == ["The dataset is already preprocessed.", ""]
def test_OnDiskDataset_preprocess_not_include_eids():
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
dataset_name = "graphbolt_test"
num_nodes = 4000
num_edges = 20000
num_classes = 10
# Generate random graph.
yaml_content = gbt.random_homo_graphbolt_graph(
test_dir,
dataset_name,
num_nodes,
num_edges,
num_classes,
)
yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
with pytest.warns(
DGLWarning,
match="Edge feature is stored, but edge IDs are not saved.",
):
gb.ondisk_dataset.preprocess_ondisk_dataset(
test_dir, include_original_edge_id=False
)
@pytest.mark.parametrize("edge_fmt", ["csv", "numpy"]) @pytest.mark.parametrize("edge_fmt", ["csv", "numpy"])
def test_OnDiskDataset_load_name(edge_fmt): def test_OnDiskDataset_load_name(edge_fmt):
"""Test preprocess of OnDiskDataset.""" """Test preprocess of OnDiskDataset."""
...@@ -2586,6 +2615,33 @@ def test_OnDiskTask_repr_homogeneous(): ...@@ -2586,6 +2615,33 @@ def test_OnDiskTask_repr_homogeneous():
assert repr(task) == expected_str, task assert repr(task) == expected_str, task
def test_OnDiskDataset_not_include_eids():
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
dataset_name = "graphbolt_test"
num_nodes = 4000
num_edges = 20000
num_classes = 10
# Generate random graph.
yaml_content = gbt.random_homo_graphbolt_graph(
test_dir,
dataset_name,
num_nodes,
num_edges,
num_classes,
)
yaml_file = os.path.join(test_dir, "metadata.yaml")
with open(yaml_file, "w") as f:
f.write(yaml_content)
with pytest.warns(
DGLWarning,
match="Edge feature is stored, but edge IDs are not saved.",
):
gb.OnDiskDataset(test_dir, include_original_edge_id=False)
def test_OnDiskTask_repr_heterogeneous(): def test_OnDiskTask_repr_heterogeneous():
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment