Unverified Commit 0742b85b authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add opt to load tasks selectively. (#6905)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent a697e791
...@@ -411,14 +411,54 @@ class OnDiskDataset(Dataset): ...@@ -411,14 +411,54 @@ class OnDiskDataset(Dataset):
self._dataset_dir, data["path"] self._dataset_dir, data["path"]
) )
def load(self, tasks: List[str] = None):
    """Load the dataset.

    Parameters
    ----------
    tasks : List[str], optional
        The name(s) of the task(s) to be loaded. For a single task, both
        a plain string and a one-element list are accepted. For multiple
        tasks, only ``List[str]`` is acceptable. ``None`` (default) loads
        every task declared in the YAML metadata.

    Returns
    -------
    OnDiskDataset
        ``self``, so that construction and loading can be chained, e.g.
        ``gb.OnDiskDataset(base_dir).load()``.

    Examples
    --------
    1. Loading via single task name "node_classification".

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks="node_classification")
    >>> len(dataset.tasks)
    1
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"

    2. Loading via single task name ["node_classification"].

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks=["node_classification"])
    >>> len(dataset.tasks)
    1
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"

    3. Loading via multiple task names ["node_classification",
    "link_prediction"].

    >>> dataset = gb.OnDiskDataset(base_dir).load(
    ...     tasks=["node_classification", "link_prediction"])
    >>> len(dataset.tasks)
    2
    >>> dataset.tasks[0].metadata["name"]
    "node_classification"
    >>> dataset.tasks[1].metadata["name"]
    "link_prediction"
    """
    # Resolve relative paths recorded in the YAML before parsing it into
    # the on-disk metadata model.
    self._convert_yaml_path_to_absolute_path()
    self._meta = OnDiskMetaData(**self._yaml_data)
    self._dataset_name = self._meta.dataset_name
    self._graph = self._load_graph(self._meta.graph_topology)
    self._feature = TorchBasedFeatureStore(self._meta.feature_data)
    # Only tasks whose "name" metadata matches ``tasks`` are kept;
    # ``tasks=None`` keeps all of them.
    self._tasks = self._init_tasks(self._meta.tasks, tasks)
    self._all_nodes_set = self._init_all_nodes_set(self._graph)
    self._loaded = True
    return self
...@@ -458,12 +498,23 @@ class OnDiskDataset(Dataset): ...@@ -458,12 +498,23 @@ class OnDiskDataset(Dataset):
self._check_loaded() self._check_loaded()
return self._all_nodes_set return self._all_nodes_set
def _init_tasks(self, tasks: List[OnDiskTaskData]) -> List[OnDiskTask]: def _init_tasks(
self, tasks: List[OnDiskTaskData], selected_tasks: List[str]
) -> List[OnDiskTask]:
"""Initialize the tasks.""" """Initialize the tasks."""
if isinstance(selected_tasks, str):
selected_tasks = [selected_tasks]
if selected_tasks and not isinstance(selected_tasks, list):
raise TypeError(
f"The type of selected_task should be list, but got {type(selected_tasks)}"
)
ret = [] ret = []
if tasks is None: if tasks is None:
return ret return ret
task_names = set()
for task in tasks: for task in tasks:
task_name = task.extra_fields.get("name", None)
if selected_tasks is None or task_name in selected_tasks:
ret.append( ret.append(
OnDiskTask( OnDiskTask(
task.extra_fields, task.extra_fields,
...@@ -472,6 +523,14 @@ class OnDiskDataset(Dataset): ...@@ -472,6 +523,14 @@ class OnDiskDataset(Dataset):
self._init_tvt_set(task.test_set), self._init_tvt_set(task.test_set),
) )
) )
if selected_tasks:
task_names.add(task_name)
if selected_tasks:
not_found_tasks = set(selected_tasks) - task_names
if len(not_found_tasks):
dgl_warning(
f"Below tasks are not found in YAML: {not_found_tasks}. Skipped."
)
return ret return ret
def _check_loaded(self): def _check_loaded(self):
......
...@@ -12,9 +12,10 @@ import pydantic ...@@ -12,9 +12,10 @@ import pydantic
import pytest import pytest
import torch import torch
import yaml import yaml
from dgl import graphbolt as gb from dgl import graphbolt as gb
from dgl.base import DGLWarning
from .. import gb_test_utils as gbt from .. import gb_test_utils as gbt
...@@ -2239,3 +2240,61 @@ def test_OnDiskTask_repr_heterogeneous(): ...@@ -2239,3 +2240,61 @@ def test_OnDiskTask_repr_heterogeneous():
)""" )"""
) )
assert str(task) == expected_str, print(task) assert str(task) == expected_str, print(task)
def test_OnDiskDataset_load_tasks_selectively():
    """Test loading tasks of an OnDiskDataset selectively via ``load(tasks=...)``."""
    with tempfile.TemporaryDirectory() as test_dir:
        # All metadata fields are specified.
        dataset_name = "graphbolt_test"
        num_nodes = 4000
        num_edges = 20000
        num_classes = 10

        # Generate random graph; the helper writes the data files and
        # returns YAML metadata that already declares one task.
        yaml_content = gbt.random_homo_graphbolt_graph(
            test_dir,
            dataset_name,
            num_nodes,
            num_edges,
            num_classes,
        )
        # Append a second task ("node_classification") so that selective
        # loading has two tasks to choose between.
        # NOTE(review): indentation of this YAML fragment must match the
        # task-list indentation emitted by random_homo_graphbolt_graph —
        # verify against gb_test_utils.
        train_path = os.path.join("set", "train.npy")
        yaml_content += f"""  - name: node_classification
    num_classes: {num_classes}
    train_set:
      - type: null
        data:
          - format: numpy
            path: {train_path}
"""
        yaml_file = os.path.join(test_dir, "metadata.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        # Case1. Test load all tasks.
        dataset = gb.OnDiskDataset(test_dir).load()
        assert len(dataset.tasks) == 2

        # Case2. Test load tasks selectively (string and list forms).
        dataset = gb.OnDiskDataset(test_dir).load(tasks="link_prediction")
        assert len(dataset.tasks) == 1
        assert dataset.tasks[0].metadata["name"] == "link_prediction"
        dataset = gb.OnDiskDataset(test_dir).load(tasks=["link_prediction"])
        assert len(dataset.tasks) == 1
        assert dataset.tasks[0].metadata["name"] == "link_prediction"

        # Case3. Test load tasks with non-existent task name: a warning
        # is emitted and the unknown name is skipped.
        with pytest.warns(
            DGLWarning,
            match="Below tasks are not found in YAML: {'fake-name'}. Skipped.",
        ):
            dataset = gb.OnDiskDataset(test_dir).load(tasks=["fake-name"])
            assert len(dataset.tasks) == 0

        # Case4. Test load tasks selectively with incorrect task type.
        with pytest.raises(TypeError):
            dataset = gb.OnDiskDataset(test_dir).load(tasks=2)

        # Drop the reference so the temporary directory can be cleaned up
        # on platforms that hold open file handles (e.g. Windows).
        dataset = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment