"docs/vscode:/vscode.git/clone" did not exist on "aaee8ff1f75941dd2be82a180826a052965ce84c"
Unverified Commit 0742b85b authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add opt to load tasks selectively. (#6905)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent a697e791
......@@ -411,14 +411,54 @@ class OnDiskDataset(Dataset):
self._dataset_dir, data["path"]
)
def load(self, tasks: List[str] = None):
    """Load the dataset into memory.

    Parameters
    ----------
    tasks : List[str], optional
        Names of the tasks to load. A single task may be given either as
        a plain string or as a one-element list; multiple tasks must be
        given as a list of strings. When ``None`` (default), every task
        present in the metadata is loaded.

    Returns
    -------
    OnDiskDataset
        ``self``, enabling call chaining.

    Examples
    --------
    1. Loading via single task name "node_classification".

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks="node_classification")
    >>> len(dataset.tasks)
    1
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"

    2. Loading via single task name ["node_classification"].

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks=["node_classification"])
    >>> len(dataset.tasks)
    1
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"

    3. Loading via multiple task names ["node_classification",
    "link_prediction"].

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks=["node_classification", "link_prediction"])
    >>> len(dataset.tasks)
    2
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"
    >>> dataset.tasks[1].metadata["name"]
    "link_prediction"
    """
    # Resolve every relative path in the YAML metadata up front so all
    # downstream loaders receive absolute paths.
    self._convert_yaml_path_to_absolute_path()
    meta = OnDiskMetaData(**self._yaml_data)
    self._meta = meta
    self._dataset_name = meta.dataset_name
    self._graph = self._load_graph(meta.graph_topology)
    self._feature = TorchBasedFeatureStore(meta.feature_data)
    # Only the tasks named in ``tasks`` are materialized (all when None).
    self._tasks = self._init_tasks(meta.tasks, tasks)
    self._all_nodes_set = self._init_all_nodes_set(self._graph)
    # Mark the dataset usable; accessors guard on this via _check_loaded().
    self._loaded = True
    return self
......@@ -458,12 +498,23 @@ class OnDiskDataset(Dataset):
self._check_loaded()
return self._all_nodes_set
def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]:
def _init_tasks(
    self, tasks: List[OnDiskTaskData], selected_tasks: List[str]
) -> List[OnDiskTask]:
    """Initialize the tasks.

    Builds an ``OnDiskTask`` for each entry of ``tasks``. When
    ``selected_tasks`` is given, only tasks whose ``name`` extra field
    matches are kept, and a warning is emitted for any requested name
    that was not found in the YAML metadata.
    """
    # A single task name is accepted as a plain string; normalize to list.
    if isinstance(selected_tasks, str):
        selected_tasks = [selected_tasks]
    # After normalization, anything truthy must be a list.
    if selected_tasks and not isinstance(selected_tasks, list):
        raise TypeError(
            f"The type of selected_task should be list, but got {type(selected_tasks)}"
        )
    ret = []
    if tasks is None:
        return ret
    # Names actually matched, used below to warn about unmatched requests.
    task_names = set()
    for task in tasks:
        task_name = task.extra_fields.get("name", None)
        # Keep the task when no filter was given or its name was requested.
        if selected_tasks is None or task_name in selected_tasks:
            ret.append(
                OnDiskTask(
                    task.extra_fields,
# NOTE(review): the diff hunk marker below elides the train/validation
# set arguments from this view — confirm against the full file.
......@@ -472,6 +523,14 @@ class OnDiskDataset(Dataset):
                    self._init_tvt_set(task.test_set),
                )
            )
            if selected_tasks:
                task_names.add(task_name)
    if selected_tasks:
        # Warn (rather than raise) about requested names absent from YAML.
        not_found_tasks = set(selected_tasks) - task_names
        if len(not_found_tasks):
            dgl_warning(
                f"Below tasks are not found in YAML: {not_found_tasks}. Skipped."
            )
    return ret
def _check_loaded(self):
......
......@@ -12,9 +12,10 @@ import pydantic
import pytest
import torch
import yaml
from dgl import graphbolt as gb
from dgl.base import DGLWarning
from .. import gb_test_utils as gbt
......@@ -2239,3 +2240,61 @@ def test_OnDiskTask_repr_heterogeneous():
)"""
)
assert str(task) == expected_str, print(task)
def test_OnDiskDataset_load_tasks_selectively():
    """Test that ``OnDiskDataset.load(tasks=...)`` loads tasks selectively."""
    with tempfile.TemporaryDirectory() as test_dir:
        # All metadata fields are specified.
        dataset_name = "graphbolt_test"
        num_nodes = 4000
        num_edges = 20000
        num_classes = 10

        # Generate random graph.
        # Presumably the generated YAML already defines a "link_prediction"
        # task (Case 1 expects two tasks in total) — verify in gb_test_utils.
        yaml_content = gbt.random_homo_graphbolt_graph(
            test_dir,
            dataset_name,
            num_nodes,
            num_edges,
            num_classes,
        )
        train_path = os.path.join("set", "train.npy")
        # Append a second, "node_classification", task to the metadata.
        # NOTE(review): the leading indentation inside this YAML fragment
        # appears stripped by the diff view — confirm against the original.
        yaml_content += f""" - name: node_classification
num_classes: {num_classes}
train_set:
- type: null
data:
- format: numpy
path: {train_path}
"""
        yaml_file = os.path.join(test_dir, "metadata.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        # Case1. Test load all tasks.
        dataset = gb.OnDiskDataset(test_dir).load()
        assert len(dataset.tasks) == 2

        # Case2. Test load tasks selectively: a bare string and a
        # one-element list must behave identically.
        dataset = gb.OnDiskDataset(test_dir).load(tasks="link_prediction")
        assert len(dataset.tasks) == 1
        assert dataset.tasks[0].metadata["name"] == "link_prediction"
        dataset = gb.OnDiskDataset(test_dir).load(tasks=["link_prediction"])
        assert len(dataset.tasks) == 1
        assert dataset.tasks[0].metadata["name"] == "link_prediction"

        # Case3. Test load tasks with non-existent task name: warns and
        # skips, yielding an empty task list.
        with pytest.warns(
            DGLWarning,
            match="Below tasks are not found in YAML: {'fake-name'}. Skipped.",
        ):
            dataset = gb.OnDiskDataset(test_dir).load(tasks=["fake-name"])
        assert len(dataset.tasks) == 0

        # Case4. Test load tasks selectively with incorrect task type.
        with pytest.raises(TypeError):
            dataset = gb.OnDiskDataset(test_dir).load(tasks=2)

        # Drop the reference so the temporary directory can be cleaned up.
        dataset = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment