Unverified commit 9111deee, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] add testcase for multiple tasks in OnDiskDataset (#6525)

parent 20866b06
...@@ -62,6 +62,77 @@ def test_OnDiskDataset_TVTSet_exceptions(): ...@@ -62,6 +62,77 @@ def test_OnDiskDataset_TVTSet_exceptions():
_ = gb.OnDiskDataset(test_dir).load() _ = gb.OnDiskDataset(test_dir).load()
def test_OnDiskDataset_multiple_tasks():
    """Test that multiple tasks are supported.

    Writes a ``metadata.yaml`` that declares two node-classification tasks
    backed by the same on-disk numpy files, loads it through
    ``gb.OnDiskDataset``, and checks each task's metadata and train set.
    """
    with tempfile.TemporaryDirectory() as test_dir:
        # Seed node IDs and random labels shared by both tasks.
        train_ids = np.arange(1000)
        train_ids_path = os.path.join(test_dir, "train_ids.npy")
        np.save(train_ids_path, train_ids)
        train_labels = np.random.randint(0, 10, size=1000)
        train_labels_path = os.path.join(test_dir, "train_labels.npy")
        np.save(train_labels_path, train_labels)

        # The nesting below is significant: each task owns a `train_set`
        # list, and each entry of that list owns a `data` list. The third
        # data item intentionally has no `name`.
        yaml_content = f"""
            tasks:
              - name: node_classification_1
                num_classes: 10
                train_set:
                  - type: null
                    data:
                      - name: seed_nodes
                        format: numpy
                        in_memory: true
                        path: {train_ids_path}
                      - name: labels
                        format: numpy
                        in_memory: true
                        path: {train_labels_path}
                      - format: numpy
                        in_memory: true
                        path: {train_labels_path}
              - name: node_classification_2
                num_classes: 10
                train_set:
                  - type: null
                    data:
                      - name: seed_nodes
                        format: numpy
                        in_memory: true
                        path: {train_ids_path}
                      - name: labels
                        format: numpy
                        in_memory: true
                        path: {train_labels_path}
                      - format: numpy
                        in_memory: true
                        path: {train_labels_path}
        """
        os.makedirs(os.path.join(test_dir, "preprocessed"), exist_ok=True)
        yaml_file = os.path.join(test_dir, "preprocessed/metadata.yaml")
        with open(yaml_file, "w") as f:
            f.write(yaml_content)

        dataset = gb.OnDiskDataset(test_dir).load()
        assert len(dataset.tasks) == 2

        for task_id in range(2):
            task = dataset.tasks[task_id]
            assert (
                task.metadata["name"] == f"node_classification_{task_id + 1}"
            )
            assert task.metadata["num_classes"] == 10

            # Verify train set.
            train_set = task.train_set
            assert len(train_set) == 1000
            assert isinstance(train_set, gb.ItemSet)
            for i, (node_id, label, _) in enumerate(train_set):
                assert node_id == train_ids[i]
                assert label == train_labels[i]
            # The unnamed third data item is reported as ``None``.
            assert train_set.names == ("seed_nodes", "labels", None)

            # Drop references before leaving the temporary directory.
            # NOTE(review): presumably so the backing files can be removed
            # on cleanup -- confirm.
            task = None
            train_set = None
        dataset = None
def test_OnDiskDataset_TVTSet_ItemSet_names(): def test_OnDiskDataset_TVTSet_ItemSet_names():
"""Test TVTSet which returns ItemSet with IDs, labels and corresponding names.""" """Test TVTSet which returns ItemSet with IDs, labels and corresponding names."""
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment